From 2f1c1194d0967322e0d83f2804677ea51de4d469 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 21 Jan 2022 14:27:08 +0100
Subject: [PATCH] Fix and improved bit mask support

---
 pystencils/bit_masks.py            | 17 +++++++++--------
 pystencils/typing/leaf_typing.py   |  5 ++++-
 pystencils_tests/test_bit_masks.py | 12 +++++++++---
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py
index 73c18688c..ad4e967f0 100644
--- a/pystencils/bit_masks.py
+++ b/pystencils/bit_masks.py
@@ -1,5 +1,5 @@
 import sympy as sp
-from pystencils.typing import get_type_of_expression
+# from pystencils.typing import get_type_of_expression
 
 
 # noinspection PyPep8Naming
@@ -22,13 +22,14 @@ class flag_cond(sp.Function):
 
     def __new__(cls, flag_bit, mask_expression, *expressions):
 
-        flag_dtype = get_type_of_expression(flag_bit)
-        if not flag_dtype.is_int():
-            raise ValueError('Argument flag_bit must be of integer type.')
-
-        mask_dtype = get_type_of_expression(mask_expression)
-        if not mask_dtype.is_int():
-            raise ValueError('Argument mask_expression must be of integer type.')
+        # TODO reintroduce checking
+        # flag_dtype = get_type_of_expression(flag_bit)
+        # if not flag_dtype.is_int():
+        #     raise ValueError('Argument flag_bit must be of integer type.')
+        #
+        # mask_dtype = get_type_of_expression(mask_expression)
+        # if not mask_dtype.is_int():
+        #     raise ValueError('Argument mask_expression must be of integer type.')
 
         return super().__new__(cls, flag_bit, mask_expression, *expressions)
 
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index b620c9c7e..ef81b529d 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -175,7 +175,10 @@ class TypeAdder:
             raise NotImplementedError('integer_functions')
         elif isinstance(expr, flag_cond):
             #   do not process the arguments to the bit shift - they must remain integers
-            raise NotImplementedError('flag_cond')
+            args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
+            collated_type = collate_types([t for _, t in args_types])
+            new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
+            return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
         #elif isinstance(expr, sp.Mul):
         #    raise NotImplementedError('sp.Mul')
         #    # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
diff --git a/pystencils_tests/test_bit_masks.py b/pystencils_tests/test_bit_masks.py
index 57371976f..423fc13cc 100644
--- a/pystencils_tests/test_bit_masks.py
+++ b/pystencils_tests/test_bit_masks.py
@@ -1,11 +1,15 @@
+import pytest
 import numpy as np
+
+import pystencils as ps
 from pystencils import Field, Assignment, create_kernel
 from pystencils.bit_masks import flag_cond
 
 
-def test_flag_condition():
+@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64])
+def test_flag_condition(mask_type):
     f_arr = np.zeros((2, 2, 2), dtype=np.float64)
-    mask_arr = np.zeros((2, 2), dtype=np.uint64)
+    mask_arr = np.zeros((2, 2), dtype=mask_type)
 
     mask_arr[0, 1] = (1 << 3)
     mask_arr[1, 0] = (1 << 5)
@@ -16,7 +20,7 @@ def test_flag_condition():
 
     v1 = 42.3
     v2 = 39.7
-    v3 = 119.87
+    v3 = 119
 
     assignments = [
         Assignment(f(0), flag_cond(3, mask(0), v1)),
@@ -25,6 +29,8 @@ def test_flag_condition():
 
     kernel = create_kernel(assignments).compile()
     kernel(f=f_arr, mask=mask_arr)
+    code = ps.get_code_str(kernel)
+    assert '119.0' in code
 
     reference = np.zeros((2, 2, 2), dtype=np.float64)
     reference[0, 1, 0] = v1
-- 
GitLab