From fc198a1751f8f27c2982b4bcaf2df6100c7e7466 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 11 Oct 2017 16:06:03 +0200 Subject: [PATCH] Further vectorization tests & bugfixes - phasefield phi sweep vectorizes successfully --- backends/cbackend.py | 19 +++++++++++++++++++ data_types.py | 15 ++++++++++++--- vectorization.py | 8 +++++++- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/backends/cbackend.py b/backends/cbackend.py index 79872b7ca..86f5597cc 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -259,6 +259,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): processed = func.format(processed, summand.term) return processed + def _print_Pow(self, expr): + """Don't use std::pow function, for small integer exponents, write as multiplication""" + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr) + assert self.instructionSet['width'] == exprType.width + + if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: + return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + else: + if expr.exp == -1: + one = self.instructionSet['makeVec'].format(1.0) + return self.instructionSet['/'].format(one, self._print(expr.base)) + elif expr.exp == 0.5: + return self.instructionSet['sqrt'].format(self._print(expr.base)) + else: + raise ValueError("Generic exponential not supported") + def _print_Mul(self, expr, insideAdd=False): exprType = getTypeOfExpression(expr) if type(exprType) is not VectorType: @@ -286,6 +304,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): a.append(item) a = a or [S.One] + # a = a or [castFunc(S.One, VectorType(createTypeFromString("double"), exprType.width))] a_str = [self._print(x) for x in a] b_str = [self._print(x) for x in b] diff --git a/data_types.py b/data_types.py index 6f766b7aa..75bc673e4 100644 --- a/data_types.py +++ b/data_types.py @@ -10,7 +10,6 @@ from pystencils.utils import allEqual # to work in conditions of sp.Piecewise castFunc has to be of type Relational as well class castFunc(sp.Function, sp.Rel): - @property def canonical(self): if hasattr(self.args[0], 'canonical'): @@ -18,6 +17,10 @@ class castFunc(sp.Function, sp.Rel): else: raise NotImplementedError() + @property + def is_commutative(self): + return self.args[0].is_commutative + class pointerArithmeticFunc(sp.Function, sp.Rel): @@ -281,8 +284,11 @@ def getTypeOfExpression(expr): elif hasattr(expr, 'func') and expr.func == castFunc: return expr.args[1] elif hasattr(expr, 'func') and expr.func == sp.Piecewise: - branchResults = [a[0] for a in expr.args] - return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults)) + collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args)) + collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args)) + if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType: + collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width) + return collatedResultType elif isinstance(expr, sp.Indexed): typedSymbol = expr.base.label return typedSymbol.dtype.baseType @@ -328,6 +334,9 @@ class Type(sp.Basic): def _sympystr(self, *args, **kwargs): return str(self) + def _sympystr(self, *args, **kwargs): + return str(self) + class BasicType(Type): @staticmethod diff --git a/vectorization.py b/vectorization.py index 979666d79..9cf113c5e 100644 --- a/vectorization.py +++ b/vectorization.py @@ -70,6 +70,9 @@ def insertVectorCasts(astNode): castedArgs = [castFunc(a, targetType) if t != targetType else a for a, t in zip(newArgs, argTypes)] return expr.func(*castedArgs) + elif expr.func is sp.Pow: + newArg = visitExpr(expr.args[0]) + return sp.Pow(newArg, expr.args[1]) elif expr.func == sp.Piecewise: newResults = [visitExpr(a[0]) for a in expr.args] newConditions = [visitExpr(a[1]) for a in expr.args] @@ -77,10 +80,13 @@ def insertVectorCasts(astNode): typesOfConditions = [getTypeOfExpression(a) for a in newConditions] resultTargetType = getTypeOfExpression(expr) + conditionTargetType = collateTypes(typesOfConditions) + if type(conditionTargetType) is VectorType and type(resultTargetType) is not VectorType: + resultTargetType = VectorType(resultTargetType, width=conditionTargetType.width) + castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a for a, t in zip(newResults, typesOfResults)] - conditionTargetType = collateTypes(typesOfConditions) castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a for a, t in zip(newConditions, typesOfConditions)] -- GitLab