Skip to content
Snippets Groups Projects
Commit fc198a17 authored by Martin Bauer's avatar Martin Bauer
Browse files

Further vectorization tests & bugfixes

- phasefield phi sweep vectorizes successfully
parent 9d1e022d
Branches
Tags
No related merge requests found
......@@ -259,6 +259,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
processed = func.format(processed, summand.term)
return processed
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr)
assert self.instructionSet['width'] == exprType.width
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
else:
if expr.exp == -1:
one = self.instructionSet['makeVec'].format(1.0)
return self.instructionSet['/'].format(one, self._print(expr.base))
elif expr.exp == 0.5:
return self.instructionSet['sqrt'].format(self._print(expr.base))
else:
raise ValueError("Generic exponential not supported")
def _print_Mul(self, expr, insideAdd=False):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
......@@ -286,6 +304,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
a.append(item)
a = a or [S.One]
# a = a or [castFunc(S.One, VectorType(createTypeFromString("double"), exprType.width))]
a_str = [self._print(x) for x in a]
b_str = [self._print(x) for x in b]
......
......@@ -10,7 +10,6 @@ from pystencils.utils import allEqual
# to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
class castFunc(sp.Function, sp.Rel):
@property
def canonical(self):
if hasattr(self.args[0], 'canonical'):
......@@ -18,6 +17,10 @@ class castFunc(sp.Function, sp.Rel):
else:
raise NotImplementedError()
@property
def is_commutative(self):
return self.args[0].is_commutative
class pointerArithmeticFunc(sp.Function, sp.Rel):
......@@ -281,8 +284,11 @@ def getTypeOfExpression(expr):
elif hasattr(expr, 'func') and expr.func == castFunc:
return expr.args[1]
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
branchResults = [a[0] for a in expr.args]
return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults))
collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args))
collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args))
if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType:
collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width)
return collatedResultType
elif isinstance(expr, sp.Indexed):
typedSymbol = expr.base.label
return typedSymbol.dtype.baseType
......@@ -328,6 +334,9 @@ class Type(sp.Basic):
def _sympystr(self, *args, **kwargs):
return str(self)
def _sympystr(self, *args, **kwargs):
return str(self)
class BasicType(Type):
@staticmethod
......
......@@ -70,6 +70,9 @@ def insertVectorCasts(astNode):
castedArgs = [castFunc(a, targetType) if t != targetType else a
for a, t in zip(newArgs, argTypes)]
return expr.func(*castedArgs)
elif expr.func is sp.Pow:
newArg = visitExpr(expr.args[0])
return sp.Pow(newArg, expr.args[1])
elif expr.func == sp.Piecewise:
newResults = [visitExpr(a[0]) for a in expr.args]
newConditions = [visitExpr(a[1]) for a in expr.args]
......@@ -77,10 +80,13 @@ def insertVectorCasts(astNode):
typesOfConditions = [getTypeOfExpression(a) for a in newConditions]
resultTargetType = getTypeOfExpression(expr)
conditionTargetType = collateTypes(typesOfConditions)
if type(conditionTargetType) is VectorType and type(resultTargetType) is not VectorType:
resultTargetType = VectorType(resultTargetType, width=conditionTargetType.width)
castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a
for a, t in zip(newResults, typesOfResults)]
conditionTargetType = collateTypes(typesOfConditions)
castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a
for a, t in zip(newConditions, typesOfConditions)]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment