From d846c95171aa5976f864e3141a09df2bf3c3d164 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 26 Oct 2018 15:47:29 +0200 Subject: [PATCH] Bugfix related to cast_func and Booleans --- backends/cbackend.py | 6 +++--- cpu/vectorization.py | 2 +- data_types.py | 21 +++++++++++++++++++-- sympyextensions.py | 2 +- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/backends/cbackend.py b/backends/cbackend.py index a8c1c32c0..f4d588b07 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -242,7 +242,7 @@ class CustomSympyPrinter(CCodePrinter): } if hasattr(expr, 'to_c'): return expr.to_c(self._print) - if expr.func == cast_func: + if isinstance(expr, cast_func): arg, data_type = expr.args if isinstance(arg, sp.Number): return self._typed_number(arg, data_type) @@ -286,11 +286,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Function(self, expr): - if expr.func == vector_memory_access: + if isinstance(expr, vector_memory_access): arg, data_type, aligned, _ = expr.args instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format("& " + self._print(arg)) - elif expr.func == cast_func: + elif isinstance(expr, cast_func): arg, data_type = expr.args if type(data_type) is VectorType: return self.instruction_set['makeVec'].format(self._print(arg)) diff --git a/cpu/vectorization.py b/cpu/vectorization.py index 32b47f262..e54109fc4 100644 --- a/cpu/vectorization.py +++ b/cpu/vectorization.py @@ -116,7 +116,7 @@ def insert_vector_casts(ast_node): """Inserts necessary casts from scalar values to vector values.""" def visit_expr(expr): - if expr.func in (cast_func, vector_memory_access): + if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access): return expr elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): new_args = [visit_expr(a) for a in expr.args] diff --git a/data_types.py b/data_types.py index a992d3b10..8a9cadaba 100644 --- a/data_types.py +++ b/data_types.py @@ -14,10 +14,22 @@ from sympy.logic.boolalg import Boolean # noinspection PyPep8Naming -class cast_func(sp.Function, Boolean): - # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well +class cast_func(sp.Function): is_Atom = True + def __new__(cls, *args, **kwargs): + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well + # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads + # to problems when for example comparing cast_func's for equality + # + # lhs = bitwise_and(a, cast_func(1, 'int')) + # rhs = cast_func(0, 'int') + # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans + # -> thus a separate class bollean_cast_func is introduced + if isinstance(args[0], Boolean): + cls = boolean_cast_func + return sp.Function.__new__(cls, *args, **kwargs) + @property def canonical(self): if hasattr(self.args[0], 'canonical'): @@ -34,6 +46,11 @@ class cast_func(sp.Function, Boolean): return self.args[1] +# noinspection PyPep8Naming +class boolean_cast_func(cast_func, Boolean): + pass + + # noinspection PyPep8Naming class vector_memory_access(cast_func): nargs = (4,) diff --git a/sympyextensions.py b/sympyextensions.py index b836d3d53..b95fc6bc5 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -474,7 +474,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], visit_children = False elif t.is_integer: pass - elif t.func is cast_func: + elif isinstance(t, cast_func): visit_children = False visit(t.args[0]) elif t.func is sp.Pow: -- GitLab