From 2e015cf51e822cb419a4f4e4faa8a8a4b0ae9e0d Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Mon, 9 Oct 2017 14:47:45 +0200 Subject: [PATCH] Vectorization: Piecewise/blend support - sympy's piecewise-defined functions are mapped to blend instructions - cast function is now a class - several bugfixes --- astnodes.py | 8 ++-- backends/cbackend.py | 75 +++++++++++++++++++------------ backends/simd_instruction_sets.py | 3 +- cpu/cpujit.py | 2 +- types.py | 22 +++++++-- vectorization.py | 43 +++++++++--------- 6 files changed, 94 insertions(+), 59 deletions(-) diff --git a/astnodes.py b/astnodes.py index 16c1e9829..cdaf8fbc2 100644 --- a/astnodes.py +++ b/astnodes.py @@ -1,7 +1,7 @@ import sympy as sp from sympy.tensor import IndexedBase from pystencils.field import Field -from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString +from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc class ResolvedFieldAccess(sp.Indexed): @@ -383,8 +383,8 @@ class SympyAssignment(Node): self._lhsSymbol = lhsSymbol self.rhs = rhsTerm self._isDeclaration = True - isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast: + isCast = self._lhsSymbol.func == castFunc + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: self._isDeclaration = False self._isConst = isConst @@ -396,7 +396,7 @@ class SympyAssignment(Node): def lhs(self, newValue): self._lhsSymbol = newValue self._isDeclaration = True - isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False + isCast = self._lhsSymbol.func == castFunc if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: self._isDeclaration = False diff --git 
a/backends/cbackend.py b/backends/cbackend.py index 5bbb9b787..98d4b9c23 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -39,6 +39,8 @@ def getHeaders(astNode): headers.update(getHeaders(a)) return headers + + # --------------------------------------- Backend Specific Nodes ------------------------------------------------------- @@ -75,6 +77,7 @@ class PrintNode(CustomCppCode): # ------------------------------------------- Printer ------------------------------------------------------------------ + class CBackend(object): def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None): @@ -204,8 +207,7 @@ class CustomSympyPrinter(CCodePrinter): return res def _print_Function(self, expr): - name = str(expr.func).lower() - if name == 'cast': + if expr.func == castFunc: arg, type = expr.args return "*((%s)(& %s))" % (PointerType(type), self._print(arg)) else: @@ -220,8 +222,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): self.instructionSet = instructionSet def _print_Function(self, expr): - name = str(expr.func).lower() - if name == 'cast': + if expr.func == castFunc: arg, dtype = expr.args if type(dtype) is VectorType: if type(arg) is ResolvedFieldAccess: @@ -291,41 +292,57 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = a_str[0] for item in a_str[1:]: - result = self.intrinsics['*'].format(result, item) + result = self.instructionSet['*'].format(result, item) if len(b) > 0: denominator_str = b_str[0] for item in b_str[1:]: - denominator_str = self.intrinsics['*'].format(denominator_str, item) - result = self.intrinsics['/'].format(result, denominator_str) + denominator_str = self.instructionSet['*'].format(denominator_str, item) + result = self.instructionSet['/'].format(result, denominator_str) if insideAdd: return sign, result else: if sign < 0: - return self.intrinsics['*'].format(self._print(S.NegativeOne), result) + return 
self.instructionSet['*'].format(self._print(S.NegativeOne), result) else: return result -# def _print_Piecewise(self, expr): -# if expr.args[-1].cond != True: -# # We need the last conditional to be a True, otherwise the resulting -# # function may not return a result. -# raise ValueError("All Piecewise expressions must contain an " -# "(expr, True) statement to be used as a default " -# "condition. Without one, the generated " -# "expression may not evaluate to anything under " -# "some condition.") -# -# result = self._print(expr.args[-1][0]) -# for trueExpr, condition in reversed(expr.args[:-1]): -# result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition)) -# return result -# -# def _print_Relational(self, expr): -# return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs) -# -# def _print_Equality(self, expr): -# return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs)) -# + def _print_Relational(self, expr): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return super(VectorizedCustomSympyPrinter, self)._print_Relational(expr) + assert self.instructionSet['width'] == exprType.width + + return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) + + def _print_Equality(self, expr): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return super(VectorizedCustomSympyPrinter, self)._print_Equality(expr) + assert self.instructionSet['width'] == exprType.width + + return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs)) + + def _print_Piecewise(self, expr): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr) + assert self.instructionSet['width'] == exprType.width + + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function 
may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + + result = self._print(expr.args[-1][0]) + for trueExpr, condition in reversed(expr.args[:-1]): + result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition)) + return result + + diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index 213b4c481..23109165b 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -66,7 +66,7 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'): if arg in ('0', '1', '2', '3', '4', '5'): argString += "{" + arg + "}," else: - argString += arg + argString += arg + "," argString = argString[:-1] + ")" result[intrinsicId] = pre + "_" + name + "_" + suf + argString @@ -80,6 +80,7 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'): result['double'] = "__m%dd" % (bitWidth,) result['float'] = "__m%d" % (bitWidth,) result['int'] = "__m%di" % (bitWidth,) + result['bool'] = "__m%dd" % (bitWidth,) result['headers'] = headers[instructionSet] return result diff --git a/cpu/cpujit.py b/cpu/cpujit.py index c5919fc6c..e893a7771 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -302,7 +302,7 @@ def runCompileStep(command): subprocess.check_output(command, env=compileEnvironment, stderr=subprocess.STDOUT, shell=shell) except subprocess.CalledProcessError as e: print(" ".join(command)) - print(e.output) + print(e.output.decode('utf8')) raise e diff --git a/types.py b/types.py index 23c487501..6442147cd 100644 --- a/types.py +++ b/types.py @@ -6,7 +6,16 @@ from sympy.core.cache import cacheit from pystencils.cache import memorycache from pystencils.utils import allEqual -castFunc = sp.Function("cast") + +# to work in conditions of sp.Piecewise castFunc has to 
be of type Relational as well +class castFunc(sp.Function, sp.Rel): + + @property + def canonical(self): + if hasattr(self.args[0], 'canonical'): + return self.args[0].canonical + else: + raise NotImplementedError() class TypedSymbol(sp.Symbol): @@ -202,11 +211,16 @@ def getTypeOfExpression(expr): elif isinstance(expr, sp.Indexed): typedSymbol = expr.base.label return typedSymbol.dtype + elif isinstance(expr, sp.boolalg.Boolean): + # if any arg is of vector type return a vector boolean, else return a normal scalar boolean + result = createTypeFromString("bool") + vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)] + if vecArgs: + result = VectorType(result, width=vecArgs[0].width) + return result elif isinstance(expr, sp.Expr): types = tuple(getTypeOfExpression(a) for a in expr.args) return collateTypes(types) - elif isinstance(expr, sp.boolalg.Boolean): - return createTypeFromString("bool") raise NotImplementedError("Could not determine type for " + str(expr)) @@ -344,6 +358,8 @@ class VectorType(Type): return self.instructionSet['double'] elif self.baseType == createTypeFromString("float"): return self.instructionSet['float'] + elif self.baseType == createTypeFromString("bool"): + return self.instructionSet['bool'] else: raise NotImplementedError() diff --git a/vectorization.py b/vectorization.py index 54a9819f7..c933a5fe0 100644 --- a/vectorization.py +++ b/vectorization.py @@ -2,26 +2,16 @@ import sympy as sp import warnings from pystencils.transformations import filteredTreeIteration -from pystencils.types import TypedSymbol, VectorType, PointerType, BasicType, getTypeOfExpression, castFunc +from pystencils.types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes import pystencils.astnodes as ast -from pystencils.utils import allEqual -def asVectorType(resolvedFieldAccess, vectorizationWidth): - """Returns a new ResolvedFieldAccess that has a vector type""" - dtype = 
resolvedFieldAccess.typedSymbol.dtype - assert type(dtype) is PointerType - basicType = dtype.baseType - assert type(basicType) is BasicType, "Structs are not supported" - - newDtype = VectorType(basicType, vectorizationWidth) - newDtype = PointerType(newDtype, dtype.const, dtype.restrict) - newTypedSymbol = TypedSymbol(resolvedFieldAccess.typedSymbol.name, newDtype) - return ast.ResolvedFieldAccess(newTypedSymbol, resolvedFieldAccess.args[1], resolvedFieldAccess.field, - resolvedFieldAccess.offsets, resolvedFieldAccess.idxCoordinateValues) +def vectorize(astNode, vectorWidth=4): + vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth) + insertVectorCasts(astNode) -def vectorize(astNode, vectorWidth=4): +def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4): """ Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if - loop bounds are constant @@ -54,20 +44,31 @@ def insertVectorCasts(astNode): Inserts necessary casts from scalar values to vector values """ def visitExpr(expr): - if expr.func in (sp.Add, sp.Mul): + if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc): newArgs = [visitExpr(a) for a in expr.args] argTypes = [getTypeOfExpression(a) for a in newArgs] if not any(type(t) is VectorType for t in argTypes): return expr else: - vectorWidths = [d.width for d in argTypes if type(d) is VectorType] - assert allEqual(vectorWidths), "Incompatible vector type widths" - vectorWidth = vectorWidths[0] - castedArgs = [castFunc(a, VectorType(t, vectorWidth)) if type(t) is not VectorType else a + targetType = collateTypes(argTypes) + castedArgs = [castFunc(a, targetType) if t != targetType else a for a, t in zip(newArgs, argTypes)] return expr.func(*castedArgs) elif expr.func == sp.Piecewise: - raise NotImplementedError() + newResults = [visitExpr(a[0]) for a in expr.args] + newConditions = [visitExpr(a[1]) for a in expr.args] + typesOfResults = 
[getTypeOfExpression(a) for a in newResults] + typesOfConditions = [getTypeOfExpression(a) for a in newConditions] + + resultTargetType = getTypeOfExpression(expr) + castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a + for a, t in zip(newResults, typesOfResults)] + + conditionTargetType = collateTypes(typesOfConditions) + castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a + for a, t in zip(newConditions, typesOfConditions)] + + return sp.Piecewise(*[(r, c) for r, c in zip(castedResults, castedConditions)]) else: return expr -- GitLab