Skip to content
Snippets Groups Projects
Commit 8a89d0bc authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes for fast_* nodes and SIMD printer

parent 2f5f6ad6
No related merge requests found
...@@ -326,14 +326,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -326,14 +326,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if type(data_type) is VectorType: if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg)) return self.instruction_set['makeVec'].format(self._print(arg))
elif expr.func == fast_division: elif expr.func == fast_division:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) result = self._scalarFallback('_print_Function', expr)
if not result:
return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
elif expr.func == fast_sqrt: elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0]))) return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif expr.func == fast_inv_sqrt: elif expr.func == fast_inv_sqrt:
if self.instruction_set['rsqrt']: result = self._scalarFallback('_print_Function', expr)
return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) if not result:
else: if self.instruction_set['rsqrt']:
return "({})".format(self.doprint(1 / sp.sqrt(expr.args[0]))) return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr): def _print_And(self, expr):
......
...@@ -2,6 +2,7 @@ import sympy as sp ...@@ -2,6 +2,7 @@ import sympy as sp
import warnings import warnings
from typing import Union, Container from typing import Union, Container
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set from pystencils.backends.simd_instruction_sets import get_vector_instruction_set
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt
from pystencils.integer_functions import modulo_floor, modulo_ceil from pystencils.integer_functions import modulo_floor, modulo_ceil
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, vector_memory_access, cast_func, \ from pystencils.data_types import TypedSymbol, VectorType, get_type_of_expression, vector_memory_access, cast_func, \
...@@ -118,10 +119,13 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -118,10 +119,13 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
def insert_vector_casts(ast_node): def insert_vector_casts(ast_node):
"""Inserts necessary casts from scalar values to vector values.""" """Inserts necessary casts from scalar values to vector values."""
handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt)
def visit_expr(expr): def visit_expr(expr):
if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access): if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
return expr return expr
elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
new_args = [visit_expr(a) for a in expr.args] new_args = [visit_expr(a) for a in expr.args]
arg_types = [get_type_of_expression(a) for a in new_args] arg_types = [get_type_of_expression(a) for a in new_args]
if not any(type(t) is VectorType for t in arg_types): if not any(type(t) is VectorType for t in arg_types):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment