Skip to content
Snippets Groups Projects
Commit 8a89d0bc authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes for fast_* nodes and SIMD printer

parent 2f5f6ad6
No related merge requests found
......@@ -326,14 +326,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg))
elif expr.func == fast_division:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
result = self._scalarFallback('_print_Function', expr)
if not result:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif expr.func == fast_inv_sqrt:
if self.instruction_set['rsqrt']:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else:
return "({})".format(self.doprint(1 / sp.sqrt(expr.args[0])))
result = self._scalarFallback('_print_Function', expr)
if not result:
if self.instruction_set['rsqrt']:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr):
......
......@@ -2,6 +2,7 @@ import sympy as sp
import warnings
from typing import Union, Container
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt
from pystencils.integer_functions import modulo_floor, modulo_ceil
from pystencils.sympyextensions import fast_subs
from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, vector_memory_access, cast_func, \
......@@ -118,10 +119,13 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
def insert_vector_casts(ast_node):
"""Inserts necessary casts from scalar values to vector values."""
handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt)
def visit_expr(expr):
if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
return expr
elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
new_args = [visit_expr(a) for a in expr.args]
arg_types = [get_type_of_expression(a) for a in new_args]
if not any(type(t) is VectorType for t in arg_types):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment