Commit fc198a17 authored by Martin Bauer's avatar Martin Bauer
Browse files

Further vectorization tests & bugfixes

- phasefield phi sweep vectorizes successfully
parent 9d1e022d
...@@ -259,6 +259,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -259,6 +259,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
processed = func.format(processed, summand.term) processed = func.format(processed, summand.term)
return processed return processed
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr)
assert self.instructionSet['width'] == exprType.width
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
else:
if expr.exp == -1:
one = self.instructionSet['makeVec'].format(1.0)
return self.instructionSet['/'].format(one, self._print(expr.base))
elif expr.exp == 0.5:
return self.instructionSet['sqrt'].format(self._print(expr.base))
else:
raise ValueError("Generic exponential not supported")
def _print_Mul(self, expr, insideAdd=False): def _print_Mul(self, expr, insideAdd=False):
exprType = getTypeOfExpression(expr) exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType: if type(exprType) is not VectorType:
...@@ -286,6 +304,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -286,6 +304,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a.append(item) a.append(item)
a = a or [S.One] a = a or [S.One]
# a = a or [castFunc(S.One, VectorType(createTypeFromString("double"), exprType.width))]
a_str = [self._print(x) for x in a] a_str = [self._print(x) for x in a]
b_str = [self._print(x) for x in b] b_str = [self._print(x) for x in b]
......
...@@ -10,7 +10,6 @@ from pystencils.utils import allEqual ...@@ -10,7 +10,6 @@ from pystencils.utils import allEqual
# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well # to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class castFunc(sp.Function, sp.Rel): class castFunc(sp.Function, sp.Rel):
@property @property
def canonical(self): def canonical(self):
if hasattr(self.args[0], 'canonical'): if hasattr(self.args[0], 'canonical'):
...@@ -18,6 +17,10 @@ class castFunc(sp.Function, sp.Rel): ...@@ -18,6 +17,10 @@ class castFunc(sp.Function, sp.Rel):
else: else:
raise NotImplementedError() raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
class pointerArithmeticFunc(sp.Function, sp.Rel): class pointerArithmeticFunc(sp.Function, sp.Rel):
...@@ -281,8 +284,11 @@ def getTypeOfExpression(expr): ...@@ -281,8 +284,11 @@ def getTypeOfExpression(expr):
elif hasattr(expr, 'func') and expr.func == castFunc: elif hasattr(expr, 'func') and expr.func == castFunc:
return expr.args[1] return expr.args[1]
elif hasattr(expr, 'func') and expr.func == sp.Piecewise: elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
branchResults = [a[0] for a in expr.args] collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args))
return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults)) collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args))
if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType:
collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width)
return collatedResultType
elif isinstance(expr, sp.Indexed): elif isinstance(expr, sp.Indexed):
typedSymbol = expr.base.label typedSymbol = expr.base.label
return typedSymbol.dtype.baseType return typedSymbol.dtype.baseType
...@@ -328,6 +334,9 @@ class Type(sp.Basic): ...@@ -328,6 +334,9 @@ class Type(sp.Basic):
def _sympystr(self, *args, **kwargs): def _sympystr(self, *args, **kwargs):
return str(self) return str(self)
def _sympystr(self, *args, **kwargs):
return str(self)
class BasicType(Type): class BasicType(Type):
@staticmethod @staticmethod
......
...@@ -70,6 +70,9 @@ def insertVectorCasts(astNode): ...@@ -70,6 +70,9 @@ def insertVectorCasts(astNode):
castedArgs = [castFunc(a, targetType) if t != targetType else a castedArgs = [castFunc(a, targetType) if t != targetType else a
for a, t in zip(newArgs, argTypes)] for a, t in zip(newArgs, argTypes)]
return expr.func(*castedArgs) return expr.func(*castedArgs)
elif expr.func is sp.Pow:
newArg = visitExpr(expr.args[0])
return sp.Pow(newArg, expr.args[1])
elif expr.func == sp.Piecewise: elif expr.func == sp.Piecewise:
newResults = [visitExpr(a[0]) for a in expr.args] newResults = [visitExpr(a[0]) for a in expr.args]
newConditions = [visitExpr(a[1]) for a in expr.args] newConditions = [visitExpr(a[1]) for a in expr.args]
...@@ -77,10 +80,13 @@ def insertVectorCasts(astNode): ...@@ -77,10 +80,13 @@ def insertVectorCasts(astNode):
typesOfConditions = [getTypeOfExpression(a) for a in newConditions] typesOfConditions = [getTypeOfExpression(a) for a in newConditions]
resultTargetType = getTypeOfExpression(expr) resultTargetType = getTypeOfExpression(expr)
conditionTargetType = collateTypes(typesOfConditions)
if type(conditionTargetType) is VectorType and type(resultTargetType) is not VectorType:
resultTargetType = VectorType(resultTargetType, width=conditionTargetType.width)
castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a
for a, t in zip(newResults, typesOfResults)] for a, t in zip(newResults, typesOfResults)]
conditionTargetType = collateTypes(typesOfConditions)
castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a
for a, t in zip(newConditions, typesOfConditions)] for a, t in zip(newConditions, typesOfConditions)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment