Commit ffd7b240 authored by Martin Bauer
Browse files

Vectorization bugfixes and improvements

- support for logical operators, and/or
- both phase field kernels can be vectorized now
parent fc198a17
......@@ -221,6 +221,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
self.instructionSet = instructionSet
def _scalarFallback(self, funcName, expr, *args, **kwargs):
    """Delegate printing of non-vector expressions to the parent printer.

    The type of ``expr`` is determined with ``getTypeOfExpression``. If it is
    not exactly a ``VectorType``, the parent-class method called ``funcName``
    is invoked with ``expr`` and any extra arguments, and its result is
    returned. For a vector-typed expression the width is checked against
    ``self.instructionSet['width']`` and ``None`` is returned, which tells the
    caller that it has to emit the vectorized code itself.
    """
    inferredType = getTypeOfExpression(expr)
    if type(inferredType) is VectorType:
        # Vector expression: nothing to delegate, only sanity-check the width.
        assert self.instructionSet['width'] == inferredType.width
        return None
    scalarPrinter = getattr(super(VectorizedCustomSympyPrinter, self), funcName)
    return scalarPrinter(expr, *args, **kwargs)
def _print_Function(self, expr):
if expr.func == castFunc:
arg, dtype = expr.args
......@@ -232,11 +240,34 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr):
    """Print a logical AND, using the vector instruction set when needed.

    Scalar expressions are handled by ``_scalarFallback``. Otherwise every
    operand is printed and the results are combined pairwise, left to right,
    through the ``'&'`` template of the instruction set.
    """
    scalarCode = self._scalarFallback('_print_And', expr)
    if scalarCode:
        return scalarCode

    operands = [self._print(arg) for arg in expr.args]
    assert len(operands) > 0
    combined, remaining = operands[0], operands[1:]
    # The template takes two operands, so longer conjunctions are nested.
    for operand in remaining:
        combined = self.instructionSet['&'].format(combined, operand)
    return combined
def _print_Or(self, expr):
    """Print a logical OR, using the vector instruction set when needed.

    Scalar expressions are handled by ``_scalarFallback``. Otherwise every
    operand is printed and the results are combined pairwise, left to right,
    through the ``'|'`` template of the instruction set.
    """
    scalarCode = self._scalarFallback('_print_Or', expr)
    if scalarCode:
        return scalarCode

    operands = [self._print(arg) for arg in expr.args]
    assert len(operands) > 0
    combined, remaining = operands[0], operands[1:]
    # The template takes two operands, so longer disjunctions are nested.
    for operand in remaining:
        combined = self.instructionSet['|'].format(combined, operand)
    return combined
def _print_Add(self, expr, order=None):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order)
assert self.instructionSet['width'] == exprType.width
result = self._scalarFallback('_print_Add', expr)
if result:
return result
summands = []
for term in expr.args:
......@@ -260,11 +291,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return processed
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr)
assert self.instructionSet['width'] == exprType.width
result = self._scalarFallback('_print_Pow', expr)
if result:
return result
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
......@@ -278,10 +307,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
raise ValueError("Generic exponential not supported")
def _print_Mul(self, expr, insideAdd=False):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Mul(expr)
assert self.instructionSet['width'] == exprType.width
result = self._scalarFallback('_print_Mul', expr)
if result:
return result
c, e = expr.as_coeff_Mul()
if c < 0:
......@@ -328,26 +356,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return result
def _print_Relational(self, expr):
    """Print a relational expression, vectorized when its type is a vector.

    The scalar/vector decision (including the width check) is made solely by
    ``_scalarFallback``; the inline copy of that check was redundant and has
    been removed. For vector expressions the template stored in the
    instruction set under ``expr.rel_op`` is filled with the printed left-
    and right-hand sides.
    """
    result = self._scalarFallback('_print_Relational', expr)
    # _scalarFallback returns None only for vector-typed expressions.
    if result is not None:
        return result
    return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
def _print_Equality(self, expr):
    """Print an equality comparison, vectorized when its type is a vector.

    The scalar/vector decision (including the width check) is made solely by
    ``_scalarFallback``; the inline copy of that check was redundant and has
    been removed. For vector expressions the ``'=='`` template of the
    instruction set is filled with the printed left- and right-hand sides.
    """
    result = self._scalarFallback('_print_Equality', expr)
    # _scalarFallback returns None only for vector-typed expressions.
    if result is not None:
        return result
    return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))
def _print_Piecewise(self, expr):
exprType = getTypeOfExpression(expr)
if type(exprType) is not VectorType:
return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr)
assert self.instructionSet['width'] == exprType.width
result = self._scalarFallback('_print_Piecewise', expr)
if result:
return result
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
......@@ -362,6 +385,3 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
for trueExpr, condition in reversed(expr.args[:-1]):
result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
return result
......@@ -13,7 +13,8 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
'<=': 'cmp[0, 1, _CMP_LE_OQ ]',
'<': 'cmp[0, 1, _CMP_NGE_UQ ]',
'>': 'cmp[0, 1, _CMP_NLE_UQ ]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
'sqrt': 'sqrt[0]',
......
......@@ -292,7 +292,7 @@ def getTypeOfExpression(expr):
elif isinstance(expr, sp.Indexed):
typedSymbol = expr.base.label
return typedSymbol.dtype.baseType
elif isinstance(expr, sp.boolalg.Boolean):
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = createTypeFromString("bool")
vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)]
......
......@@ -60,7 +60,8 @@ def insertVectorCasts(astNode):
Inserts necessary casts from scalar values to vector values
"""
def visitExpr(expr):
if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc):
if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc) or \
isinstance(expr, sp.boolalg.BooleanFunction):
newArgs = [visitExpr(a) for a in expr.args]
argTypes = [getTypeOfExpression(a) for a in newArgs]
if not any(type(t) is VectorType for t in argTypes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment