Commit d846c951 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfix related to cast_func and Booleans

parent a6754c61
......@@ -242,7 +242,7 @@ class CustomSympyPrinter(CCodePrinter):
if hasattr(expr, 'to_c'):
return expr.to_c(self._print)
if expr.func == cast_func:
if isinstance(expr, cast_func):
arg, data_type = expr.args
if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type)
......@@ -286,11 +286,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return None
def _print_Function(self, expr):
if expr.func == vector_memory_access:
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _ = expr.args
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg))
elif expr.func == cast_func:
elif isinstance(expr, cast_func):
arg, data_type = expr.args
if type(data_type) is VectorType:
return self.instruction_set['makeVec'].format(self._print(arg))
......@@ -116,7 +116,7 @@ def insert_vector_casts(ast_node):
"""Inserts necessary casts from scalar values to vector values."""
def visit_expr(expr):
if expr.func in (cast_func, vector_memory_access):
if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
return expr
elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
new_args = [visit_expr(a) for a in expr.args]
......@@ -14,10 +14,22 @@ from sympy.logic.boolalg import Boolean
# noinspection PyPep8Naming
class cast_func(sp.Function, Boolean):
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
class cast_func(sp.Function):
is_Atom = True
def __new__(cls, *args, **kwargs):
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
# lhs = bitwise_and(a, cast_func(1, 'int'))
# rhs = cast_func(0, 'int')
# print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
# -> thus a separate class bollean_cast_func is introduced
if isinstance(args[0], Boolean):
cls = boolean_cast_func
return sp.Function.__new__(cls, *args, **kwargs)
def canonical(self):
if hasattr(self.args[0], 'canonical'):
......@@ -34,6 +46,11 @@ class cast_func(sp.Function, Boolean):
return self.args[1]
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
nargs = (4,)
......@@ -474,7 +474,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
visit_children = False
elif t.is_integer:
elif t.func is cast_func:
elif isinstance(t, cast_func):
visit_children = False
elif t.func is sp.Pow:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment