From 9f501a360758c55390cfa97b38165fc29068f30e Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 4 Jul 2021 14:57:50 +0000
Subject: [PATCH] Bit Flag Conditional

---
 pystencils/bit_masks.py            | 62 ++++++++++++++++++++++++++++++++++++
 pystencils/transformations.py      |  5 +++
 pystencils_tests/test_bit_masks.py | 49 ++++++++++++++++++++++++++++
 3 files changed, 116 insertions(+)
 create mode 100644 pystencils/bit_masks.py
 create mode 100644 pystencils_tests/test_bit_masks.py

diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py
new file mode 100644
index 000000000..0fab63b25
--- /dev/null
+++ b/pystencils/bit_masks.py
@@ -0,0 +1,62 @@
+import sympy as sp
+from pystencils.data_types import get_type_of_expression
+
+
+# noinspection PyPep8Naming
+class flag_cond(sp.Function):
+    """Evaluates a flag condition on a bit mask, and returns the value of one of two expressions,
+    depending on whether the flag is set.
+
+    Three argument version:
+    ```
+        flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0
+    ```
+
+    Four argument version:
+    ```
+        flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else
+    ```
+
+    ``flag_bit`` is the zero-based position of the tested bit: the generated code shifts the mask right
+    by ``flag_bit`` places and looks at the lowest bit of the result.
+    """
+
+    nargs = (3, 4)
+
+    def __new__(cls, flag_bit, mask_expression, *expressions):
+        """Creates the node after checking that flag bit and mask are of integer type.
+
+        Raises:
+            ValueError: if `flag_bit` or `mask_expression` is not of integer type
+        """
+        flag_dtype = get_type_of_expression(flag_bit)
+        if not flag_dtype.is_int():
+            raise ValueError('Argument flag_bit must be of integer type.')
+
+        mask_dtype = get_type_of_expression(mask_expression)
+        if not mask_dtype.is_int():
+            raise ValueError('Argument mask_expression must be of integer type.')
+
+        return super().__new__(cls, flag_bit, mask_expression, *expressions)
+
+    def to_c(self, print_func):
+        """Returns the C code for this node; `print_func` is applied to every argument to obtain its code.
+
+        The selection is arithmetic rather than a branch: the tested bit (0 or 1) is multiplied with the
+        'then' expression and, in the four argument version, its complement with the 'else' expression.
+        """
+        flag_bit_code = print_func(self.args[0])
+        mask_code = print_func(self.args[1])
+        then_code = print_func(self.args[2])
+
+        # '>>' binds tighter than '&' in C, so this is ((mask >> flag_bit) & 1), i.e. exactly 0 or 1
+        flag_code = f"(({mask_code}) >> ({flag_bit_code}) & 1)"
+        code = f"{flag_code} * ({then_code})"
+
+        if len(self.args) > 3:
+            else_code = print_func(self.args[3])
+            # Complement the isolated bit. XOR-ing the merely shifted mask with 1 would keep all higher
+            # mask bits, so the 'else' expression would get scaled by values other than 0 or 1.
+            code += f" + ({flag_code} ^ 1) * ({else_code})"
+
+        return code
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 1ce3774b6..b037225b9 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -21,6 +21,7 @@ from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.slicing import normalize_slice
 from pystencils.integer_functions import int_div
+from pystencils.bit_masks import flag_cond
 
 
 class NestedScopes:
@@ -876,6 +877,10 @@ class KernelConstraintsCheck:
                         else cast_func(a, arg_type)
                         for a in new_args]
             return rhs.func(*new_args)
+        elif isinstance(rhs, flag_cond):
+            # do not process the arguments to the bit shift - they must remain integers
+            processed_args = (self.process_expression(a) for a in rhs.args[2:])
+            return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
         elif isinstance(rhs, sp.Mul):
             new_args = [
                 self.process_expression(arg, type_constants)
diff --git a/pystencils_tests/test_bit_masks.py b/pystencils_tests/test_bit_masks.py
new file mode 100644
index 000000000..5b3ad66a3
--- /dev/null
+++ b/pystencils_tests/test_bit_masks.py
@@ -0,0 +1,49 @@
+import numpy as np
+import sympy as sp
+from pystencils import Field, Assignment, create_kernel
+from pystencils.bit_masks import flag_cond
+from pystencils import TypedSymbol
+
+
+def test_flag_condition():
+    f_arr = np.zeros((2,2,3), dtype=np.float64)
+    mask_arr = np.zeros((2,2), dtype=np.uint64)
+
+    mask_arr[0,1] = (1<<3)
+    mask_arr[1,0] = (1<<5)
+    mask_arr[1,1] = (1<<3) + (1 << 5)
+
+    f = Field.create_from_numpy_array('f', f_arr, index_dimensions=1)
+    mask = Field.create_from_numpy_array('mask', mask_arr)
+
+    v1 = 42.3
+    v2 = 39.7
+    v3 = 119.87
+
+    assignments = [
+        Assignment(f(0), flag_cond(3, mask(0), v1)),
+        Assignment(f(1), flag_cond(5, mask(0), v2, v3)),
+        # bit 3 is tested while bit 5 may also be set: higher mask bits must not leak into the result
+        Assignment(f(2), flag_cond(3, mask(0), v2, v3))
+    ]
+
+    kernel = create_kernel(assignments).compile()
+    kernel(f=f_arr, mask=mask_arr)
+
+    reference = np.zeros((2,2,3), dtype=np.float64)
+    reference[0,1,0] = v1
+    reference[1,1,0] = v1
+
+    reference[0,0,1] = v3
+    reference[0,1,1] = v3
+
+    reference[1,0,1] = v2
+    reference[1,1,1] = v2
+
+    reference[0,0,2] = v3
+    reference[1,0,2] = v3
+
+    reference[0,1,2] = v2
+    reference[1,1,2] = v2
+
+    np.testing.assert_array_equal(f_arr, reference)
-- 
GitLab