From 9f501a360758c55390cfa97b38165fc29068f30e Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 4 Jul 2021 14:57:50 +0000
Subject: [PATCH] Bit Flag Conditional

---
 pystencils/bit_masks.py            | 52 ++++++++++++++++++++++++++++++
 pystencils/transformations.py      |  5 +++
 pystencils_tests/test_bit_masks.py | 41 +++++++++++++++++++++++
 3 files changed, 98 insertions(+)
 create mode 100644 pystencils/bit_masks.py
 create mode 100644 pystencils_tests/test_bit_masks.py

diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py
new file mode 100644
index 000000000..0fab63b25
--- /dev/null
+++ b/pystencils/bit_masks.py
@@ -0,0 +1,52 @@
+import sympy as sp
+from pystencils.data_types import get_type_of_expression
+
+
+# noinspection PyPep8Naming
+class flag_cond(sp.Function):
+    """Evaluates a flag condition on a bit mask, and returns the value of one of two expressions,
+    depending on whether the flag is set. 
+
+    Three argument version:
+    ```
+        flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0
+    ```
+
+    Four argument version:
+    ```
+        flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else
+    ```
+    """
+
+    nargs = (3, 4)
+
+    def __new__(cls, flag_bit, mask_expression, *expressions):
+
+        flag_dtype = get_type_of_expression(flag_bit)
+        if not flag_dtype.is_int():
+            raise ValueError('Argument flag_bit must be of integer type.')
+
+        mask_dtype = get_type_of_expression(mask_expression)
+        if not mask_dtype.is_int():
+            raise ValueError('Argument mask_expression must be of integer type.')
+
+        return super().__new__(cls, flag_bit, mask_expression, *expressions)
+
+    def to_c(self, print_func):
+        flag_bit = self.args[0]
+        mask = self.args[1]
+
+        then_expression = self.args[2]
+
+        flag_bit_code = print_func(flag_bit)
+        mask_code = print_func(mask)
+        then_code = print_func(then_expression)
+
+        code = f"(({mask_code}) >> ({flag_bit_code}) & 1) * ({then_code})"
+
+        if len(self.args) > 3:
+            else_expression = self.args[3]
+            else_code = print_func(else_expression)
+            code += f" + (({mask_code}) >> ({flag_bit_code}) ^ 1) * ({else_code})"
+
+        return code
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 1ce3774b6..b037225b9 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -21,6 +21,7 @@ from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.slicing import normalize_slice
 from pystencils.integer_functions import int_div
+from pystencils.bit_masks import flag_cond
 
 
 class NestedScopes:
@@ -876,6 +877,10 @@ class KernelConstraintsCheck:
                         else cast_func(a, arg_type)
                         for a in new_args]
             return rhs.func(*new_args)
+        elif isinstance(rhs, flag_cond):
+            #   do not process the arguments to the bit shift - they must remain integers
+            processed_args = (self.process_expression(a) for a in rhs.args[2:])
+            return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
         elif isinstance(rhs, sp.Mul):
             new_args = [
                 self.process_expression(arg, type_constants)
diff --git a/pystencils_tests/test_bit_masks.py b/pystencils_tests/test_bit_masks.py
new file mode 100644
index 000000000..5b3ad66a3
--- /dev/null
+++ b/pystencils_tests/test_bit_masks.py
@@ -0,0 +1,41 @@
+import numpy as np
+import sympy as sp
+from pystencils import Field, Assignment, create_kernel
+from pystencils.bit_masks import flag_cond
+from pystencils import TypedSymbol
+
+
+def test_flag_condition():
+    f_arr = np.zeros((2,2,2), dtype=np.float64)
+    mask_arr = np.zeros((2,2), dtype=np.uint64)
+
+    mask_arr[0,1] = (1<<3)
+    mask_arr[1,0] = (1<<5)
+    mask_arr[1,1] = (1<<3) + (1 << 5)
+
+    f = Field.create_from_numpy_array('f', f_arr, index_dimensions=1)
+    mask = Field.create_from_numpy_array('mask', mask_arr)
+
+    v1 = 42.3
+    v2 = 39.7
+    v3 = 119.87
+
+    assignments = [
+        Assignment(f(0), flag_cond(3, mask(0), v1)),
+        Assignment(f(1), flag_cond(5, mask(0), v2, v3))
+    ]
+
+    kernel = create_kernel(assignments).compile()
+    kernel(f=f_arr, mask=mask_arr)
+
+    reference = np.zeros((2,2,2), dtype=np.float64)
+    reference[0,1,0] = v1
+    reference[1,1,0] = v1
+
+    reference[0,0,1] = v3
+    reference[0,1,1] = v3
+
+    reference[1,0,1] = v2
+    reference[1,1,1] = v2
+
+    np.testing.assert_array_equal(f_arr, reference)
-- 
GitLab