diff --git a/backends/cbackend.py b/backends/cbackend.py index 86f5597cc5792e90005fd22d9175ad87bcdfacc9..f19dec05f806a8b3c2b93c0e7b14d026fd889817 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -221,6 +221,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats) self.instructionSet = instructionSet + def _scalarFallback(self, funcName, expr, *args, **kwargs): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs) + else: + assert self.instructionSet['width'] == exprType.width + return None + def _print_Function(self, expr): if expr.func == castFunc: arg, dtype = expr.args @@ -232,11 +240,34 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) + def _print_And(self, expr): + result = self._scalarFallback('_print_And', expr) + if result: + return result + + argStrings = [self._print(a) for a in expr.args] + assert len(argStrings) > 0 + result = argStrings[0] + for item in argStrings[1:]: + result = self.instructionSet['&'].format(result, item) + return result + + def _print_Or(self, expr): + result = self._scalarFallback('_print_Or', expr) + if result: + return result + + argStrings = [self._print(a) for a in expr.args] + assert len(argStrings) > 0 + result = argStrings[0] + for item in argStrings[1:]: + result = self.instructionSet['|'].format(result, item) + return result + def _print_Add(self, expr, order=None): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Add', expr) + if result: + return result summands = [] for term in expr.args: @@ -260,11 +291,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return processed def _print_Pow(self, expr): - """Don't use std::pow function, for small integer exponents, write as multiplication""" - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Pow', expr) + if result: + return result if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" @@ -278,10 +307,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): raise ValueError("Generic exponential not supported") def _print_Mul(self, expr, insideAdd=False): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Mul(expr) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Mul', expr) + if result: + return result c, e = expr.as_coeff_Mul() if c < 0: @@ -328,26 +356,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return result def _print_Relational(self, expr): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Relational(expr) - assert self.instructionSet['width'] == exprType.width - + result = self._scalarFallback('_print_Relational', expr) + if result: + return result return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) def _print_Equality(self, expr): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Equality(expr) - assert self.instructionSet['width'] == exprType.width - + result = self._scalarFallback('_print_Equality', expr) + if result: + return result return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs)) def _print_Piecewise(self, expr): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Piecewise', expr) + if result: + return result if expr.args[-1].cond != True: # We need the last conditional to be a True, otherwise the resulting @@ -362,6 +385,3 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): for trueExpr, condition in reversed(expr.args[:-1]): result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition)) return result - - - diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index 23109165b3764068774dc8f7039725c59962b9e7..6760837c8eb91ce1e9bb8f3cc43ac03ada7d8a6e 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -13,7 +13,8 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'): '<=': 'cmp[0, 1, _CMP_LE_OQ ]', '<': 'cmp[0, 1, _CMP_NGE_UQ ]', '>': 'cmp[0, 1, _CMP_NLE_UQ ]', - + '&': 'and[0, 1]', + '|': 'or[0, 1]', 'blendv': 'blendv[0, 1, 2]', 'sqrt': 'sqrt[0]', diff --git a/data_types.py b/data_types.py index 75bc673e4ac37ba30ba6eba670afe0b4ad56e25b..d2a8a52b2a3464b68f2eb66940419ea647827565 100644 --- a/data_types.py +++ b/data_types.py @@ -292,7 +292,7 @@ def getTypeOfExpression(expr): elif isinstance(expr, sp.Indexed): typedSymbol = expr.base.label return typedSymbol.dtype.baseType - elif isinstance(expr, sp.boolalg.Boolean): + elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean result = createTypeFromString("bool") vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)] diff --git a/vectorization.py b/vectorization.py index 9cf113c5eabf69e6332a89d287fe7ab6265d0543..291f8a96af45d47ff7ff78f56b91241e496a9aab 100644 --- a/vectorization.py +++ b/vectorization.py @@ -60,7 +60,8 @@ def insertVectorCasts(astNode): Inserts necessary casts from scalar values to vector values """ def visitExpr(expr): - if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc): + if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc) or \ + isinstance(expr, sp.boolalg.BooleanFunction): newArgs = [visitExpr(a) for a in expr.args] argTypes = [getTypeOfExpression(a) for a in newArgs] if not any(type(t) is VectorType for t in argTypes):