From d846c95171aa5976f864e3141a09df2bf3c3d164 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 26 Oct 2018 15:47:29 +0200
Subject: [PATCH] Bugfix related to cast_func and Booleans

---
 backends/cbackend.py |  6 +++---
 cpu/vectorization.py |  2 +-
 data_types.py        | 21 +++++++++++++++++++--
 sympyextensions.py   |  2 +-
 4 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index a8c1c32c0..f4d588b07 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -242,7 +242,7 @@ class CustomSympyPrinter(CCodePrinter):
         }
         if hasattr(expr, 'to_c'):
             return expr.to_c(self._print)
-        if expr.func == cast_func:
+        if isinstance(expr, cast_func):
             arg, data_type = expr.args
             if isinstance(arg, sp.Number):
                 return self._typed_number(arg, data_type)
@@ -286,11 +286,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             return None
 
     def _print_Function(self, expr):
-        if expr.func == vector_memory_access:
+        if isinstance(expr, vector_memory_access):
             arg, data_type, aligned, _ = expr.args
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format("& " + self._print(arg))
-        elif expr.func == cast_func:
+        elif isinstance(expr, cast_func):
             arg, data_type = expr.args
             if type(data_type) is VectorType:
                 return self.instruction_set['makeVec'].format(self._print(arg))
diff --git a/cpu/vectorization.py b/cpu/vectorization.py
index 32b47f262..e54109fc4 100644
--- a/cpu/vectorization.py
+++ b/cpu/vectorization.py
@@ -116,7 +116,7 @@ def insert_vector_casts(ast_node):
     """Inserts necessary casts from scalar values to vector values."""
 
     def visit_expr(expr):
-        if expr.func in (cast_func, vector_memory_access):
+        if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
             return expr
         elif expr.func in (sp.Add, sp.Mul) or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
             new_args = [visit_expr(a) for a in expr.args]
diff --git a/data_types.py b/data_types.py
index a992d3b10..8a9cadaba 100644
--- a/data_types.py
+++ b/data_types.py
@@ -14,10 +14,22 @@ from sympy.logic.boolalg import Boolean
 
 
 # noinspection PyPep8Naming
-class cast_func(sp.Function, Boolean):
-    # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
+class cast_func(sp.Function):
     is_Atom = True
 
+    def __new__(cls, *args, **kwargs):
+        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
+        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
+        # to problems when for example comparing cast_func's for equality
+        #
+        # lhs = bitwise_and(a, cast_func(1, 'int'))
+        # rhs = cast_func(0, 'int')
+        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
+        # -> thus a separate class bollean_cast_func is introduced
+        if isinstance(args[0], Boolean):
+            cls = boolean_cast_func
+        return sp.Function.__new__(cls, *args, **kwargs)
+
     @property
     def canonical(self):
         if hasattr(self.args[0], 'canonical'):
@@ -34,6 +46,11 @@ class cast_func(sp.Function, Boolean):
         return self.args[1]
 
 
+# noinspection PyPep8Naming
+class boolean_cast_func(cast_func, Boolean):
+    pass
+
+
 # noinspection PyPep8Naming
 class vector_memory_access(cast_func):
     nargs = (4,)
diff --git a/sympyextensions.py b/sympyextensions.py
index b836d3d53..b95fc6bc5 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -474,7 +474,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
             visit_children = False
         elif t.is_integer:
             pass
-        elif t.func is cast_func:
+        elif isinstance(t, cast_func):
             visit_children = False
             visit(t.args[0])
         elif t.func is sp.Pow:
-- 
GitLab