diff --git a/backends/cbackend.py b/backends/cbackend.py index f3336be4bf12ba210946f22055063a115ea1c9d8..f0787bc7dd41701217ab1a7ad868a178416c449b 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -326,14 +326,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if type(data_type) is VectorType: return self.instruction_set['makeVec'].format(self._print(arg)) elif expr.func == fast_division: - return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) + result = self._scalarFallback('_print_Function', expr) + if not result: + return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) elif expr.func == fast_sqrt: return "({})".format(self._print(sp.sqrt(expr.args[0]))) elif expr.func == fast_inv_sqrt: - if self.instruction_set['rsqrt']: - return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) - else: - return "({})".format(self.doprint(1 / sp.sqrt(expr.args[0]))) + result = self._scalarFallback('_print_Function', expr) + if not result: + if self.instruction_set['rsqrt']: + return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) + else: + return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) def _print_And(self, expr): diff --git a/cpu/vectorization.py b/cpu/vectorization.py index 6556a01f68b587aad92ed887a7e2b0b0b0cb63d2..6a55b692e4e27b992d728dd19badc584ebc8d46b 100644 --- a/cpu/vectorization.py +++ b/cpu/vectorization.py @@ -2,6 +2,7 @@ import sympy as sp import warnings from typing import Union, Container from pystencils.backends.simd_instruction_sets import get_vector_instruction_set +from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt from pystencils.integer_functions import modulo_floor, modulo_ceil from pystencils.sympyextensions import fast_subs from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, vector_memory_access, cast_func, \ @@ -118,10 +119,13 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a def insert_vector_casts(ast_node): """Inserts necessary casts from scalar values to vector values.""" + handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt) + def visit_expr(expr): + if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access): return expr - elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): + elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): new_args = [visit_expr(a) for a in expr.args] arg_types = [get_type_of_expression(a) for a in new_args] if not any(type(t) is VectorType for t in arg_types):