diff --git a/backends/cbackend.py b/backends/cbackend.py index a8c1c32c0e17d9b27be694304bb78acc89c2faf0..f4d588b0741538d40cdbd62165fbc4c06f28316a 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -242,7 +242,7 @@ class CustomSympyPrinter(CCodePrinter): } if hasattr(expr, 'to_c'): return expr.to_c(self._print) - if expr.func == cast_func: + if isinstance(expr, cast_func): arg, data_type = expr.args if isinstance(arg, sp.Number): return self._typed_number(arg, data_type) @@ -286,11 +286,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Function(self, expr): - if expr.func == vector_memory_access: + if isinstance(expr, vector_memory_access): arg, data_type, aligned, _ = expr.args instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format("& " + self._print(arg)) - elif expr.func == cast_func: + elif isinstance(expr, cast_func): arg, data_type = expr.args if type(data_type) is VectorType: return self.instruction_set['makeVec'].format(self._print(arg)) diff --git a/cpu/vectorization.py b/cpu/vectorization.py index 32b47f2622747f7d2586bc0d5a0b14cbfd1008b3..e54109fc4e9fbd88832cfc17e18431380e77b732 100644 --- a/cpu/vectorization.py +++ b/cpu/vectorization.py @@ -116,7 +116,7 @@ def insert_vector_casts(ast_node): """Inserts necessary casts from scalar values to vector values.""" def visit_expr(expr): - if expr.func in (cast_func, vector_memory_access): + if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access): return expr elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): new_args = [visit_expr(a) for a in expr.args] diff --git a/data_types.py b/data_types.py index a992d3b106a9374ddc359604554eb65aa411ad97..8a9cadaba8a2e8e66e2040928913b7430a125130 100644 --- a/data_types.py +++ b/data_types.py @@ -14,10 +14,22 @@ from sympy.logic.boolalg import Boolean # noinspection PyPep8Naming -class cast_func(sp.Function, Boolean): - # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well +class cast_func(sp.Function): is_Atom = True + def __new__(cls, *args, **kwargs): + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well + # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads + # to problems when for example comparing cast_func's for equality + # + # lhs = bitwise_and(a, cast_func(1, 'int')) + # rhs = cast_func(0, 'int') + # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans + # -> thus a separate class bollean_cast_func is introduced + if isinstance(args[0], Boolean): + cls = boolean_cast_func + return sp.Function.__new__(cls, *args, **kwargs) + @property def canonical(self): if hasattr(self.args[0], 'canonical'): @@ -34,6 +46,11 @@ class cast_func(sp.Function, Boolean): return self.args[1] +# noinspection PyPep8Naming +class boolean_cast_func(cast_func, Boolean): + pass + + # noinspection PyPep8Naming class vector_memory_access(cast_func): nargs = (4,) diff --git a/sympyextensions.py b/sympyextensions.py index b836d3d53ee870184ab63f602fa543ea21b2bd39..b95fc6bc5dc40bc1ebc3b5677367d1849fe7c10a 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -474,7 +474,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], visit_children = False elif t.is_integer: pass - elif t.func is cast_func: + elif isinstance(t, cast_func): visit_children = False visit(t.args[0]) elif t.func is sp.Pow: