From ffd7b240916ae034bfcef611777c6678369ab8e8 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Thu, 12 Oct 2017 11:11:49 +0200 Subject: [PATCH] Vectorization bugfixes and improvements - support for logical operators, and/or - both phase field kernels can be vectorized now --- backends/cbackend.py | 80 +++++++++++++++++++------------ backends/simd_instruction_sets.py | 3 +- data_types.py | 2 +- vectorization.py | 3 +- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/backends/cbackend.py b/backends/cbackend.py index 86f5597cc..f19dec05f 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -221,6 +221,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats) self.instructionSet = instructionSet + def _scalarFallback(self, funcName, expr, *args, **kwargs): + exprType = getTypeOfExpression(expr) + if type(exprType) is not VectorType: + return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs) + else: + assert self.instructionSet['width'] == exprType.width + return None + def _print_Function(self, expr): if expr.func == castFunc: arg, dtype = expr.args @@ -232,11 +240,34 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) + def _print_And(self, expr): + result = self._scalarFallback('_print_And', expr) + if result: + return result + + argStrings = [self._print(a) for a in expr.args] + assert len(argStrings) > 0 + result = argStrings[0] + for item in argStrings[1:]: + result = self.instructionSet['&'].format(result, item) + return result + + def _print_Or(self, expr): + result = self._scalarFallback('_print_Or', expr) + if result: + return result + + argStrings = [self._print(a) for a in expr.args] + assert len(argStrings) > 0 + result = argStrings[0] + for item in argStrings[1:]: + result = self.instructionSet['|'].format(result, item) + return result + def _print_Add(self, expr, order=None): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Add', expr) + if result: + return result summands = [] for term in expr.args: @@ -260,11 +291,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return processed def _print_Pow(self, expr): - """Don't use std::pow function, for small integer exponents, write as multiplication""" - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Pow', expr) + if result: + return result if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" @@ -278,10 +307,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): raise ValueError("Generic exponential not supported") def _print_Mul(self, expr, insideAdd=False): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Mul(expr) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Mul', expr) + if result: + return result c, e = expr.as_coeff_Mul() if c < 0: @@ -328,26 +356,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return result def _print_Relational(self, expr): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Relational(expr) - assert self.instructionSet['width'] == exprType.width - + result = self._scalarFallback('_print_Relational', expr) + if result: + return result return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) def _print_Equality(self, expr): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Equality(expr) - assert self.instructionSet['width'] == exprType.width - + result = self._scalarFallback('_print_Equality', expr) + if result: + return result return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs)) def _print_Piecewise(self, expr): - exprType = getTypeOfExpression(expr) - if type(exprType) is not VectorType: - return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr) - assert self.instructionSet['width'] == exprType.width + result = self._scalarFallback('_print_Piecewise', expr) + if result: + return result if expr.args[-1].cond != True: # We need the last conditional to be a True, otherwise the resulting @@ -362,6 +385,3 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): for trueExpr, condition in reversed(expr.args[:-1]): result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition)) return result - - - diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index 23109165b..6760837c8 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -13,7 +13,8 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'): '<=': 'cmp[0, 1, _CMP_LE_OQ ]', '<': 'cmp[0, 1, _CMP_NGE_UQ ]', '>': 'cmp[0, 1, _CMP_NLE_UQ ]', - + '&': 'and[0, 1]', + '|': 'or[0, 1]', 'blendv': 'blendv[0, 1, 2]', 'sqrt': 'sqrt[0]', diff --git a/data_types.py b/data_types.py index 75bc673e4..d2a8a52b2 100644 --- a/data_types.py +++ b/data_types.py @@ -292,7 +292,7 @@ def getTypeOfExpression(expr): elif isinstance(expr, sp.Indexed): typedSymbol = expr.base.label return typedSymbol.dtype.baseType - elif isinstance(expr, sp.boolalg.Boolean): + elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean result = createTypeFromString("bool") vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)] diff --git a/vectorization.py b/vectorization.py index 9cf113c5e..291f8a96a 100644 --- a/vectorization.py +++ b/vectorization.py @@ -60,7 +60,8 @@ def insertVectorCasts(astNode): Inserts necessary casts from scalar values to vector values """ def visitExpr(expr): - if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc): + if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc) or \ + isinstance(expr, sp.boolalg.BooleanFunction): newArgs = [visitExpr(a) for a in expr.args] argTypes = [getTypeOfExpression(a) for a in newArgs] if not any(type(t) is VectorType for t in argTypes): -- GitLab