Commit 2ed62a8e authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'bit_masks' into 'master'

Bit Flag Conditional

See merge request pycodegen/pystencils!257
parents 699144dd 9f501a36
import sympy as sp
from pystencils.data_types import get_type_of_expression
# noinspection PyPep8Naming
class flag_cond(sp.Function):
"""Evaluates a flag condition on a bit mask, and returns the value of one of two expressions,
depending on whether the flag is set.
Three argument version:
```
flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0
```
Four argument version:
```
flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else
```
"""
nargs = (3, 4)
def __new__(cls, flag_bit, mask_expression, *expressions):
flag_dtype = get_type_of_expression(flag_bit)
if not flag_dtype.is_int():
raise ValueError('Argument flag_bit must be of integer type.')
mask_dtype = get_type_of_expression(mask_expression)
if not mask_dtype.is_int():
raise ValueError('Argument mask_expression must be of integer type.')
return super().__new__(cls, flag_bit, mask_expression, *expressions)
def to_c(self, print_func):
flag_bit = self.args[0]
mask = self.args[1]
then_expression = self.args[2]
flag_bit_code = print_func(flag_bit)
mask_code = print_func(mask)
then_code = print_func(then_expression)
code = f"(({mask_code}) >> ({flag_bit_code}) & 1) * ({then_code})"
if len(self.args) > 3:
else_expression = self.args[3]
else_code = print_func(else_expression)
code += f" + (({mask_code}) >> ({flag_bit_code}) ^ 1) * ({else_code})"
return code
......@@ -21,6 +21,7 @@ from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
from pystencils.bit_masks import flag_cond
class NestedScopes:
......@@ -876,6 +877,10 @@ class KernelConstraintsCheck:
else cast_func(a, arg_type)
for a in new_args]
return rhs.func(*new_args)
elif isinstance(rhs, flag_cond):
# do not process the arguments to the bit shift - they must remain integers
processed_args = (self.process_expression(a) for a in rhs.args[2:])
return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
elif isinstance(rhs, sp.Mul):
new_args = [
self.process_expression(arg, type_constants)
......
import numpy as np
import sympy as sp
from pystencils import Field, Assignment, create_kernel
from pystencils.bit_masks import flag_cond
from pystencils import TypedSymbol
def test_flag_condition():
f_arr = np.zeros((2,2,2), dtype=np.float64)
mask_arr = np.zeros((2,2), dtype=np.uint64)
mask_arr[0,1] = (1<<3)
mask_arr[1,0] = (1<<5)
mask_arr[1,1] = (1<<3) + (1 << 5)
f = Field.create_from_numpy_array('f', f_arr, index_dimensions=1)
mask = Field.create_from_numpy_array('mask', mask_arr)
v1 = 42.3
v2 = 39.7
v3 = 119.87
assignments = [
Assignment(f(0), flag_cond(3, mask(0), v1)),
Assignment(f(1), flag_cond(5, mask(0), v2, v3))
]
kernel = create_kernel(assignments).compile()
kernel(f=f_arr, mask=mask_arr)
reference = np.zeros((2,2,2), dtype=np.float64)
reference[0,1,0] = v1
reference[1,1,0] = v1
reference[0,0,1] = v3
reference[0,1,1] = v3
reference[1,0,1] = v2
reference[1,1,1] = v2
np.testing.assert_array_equal(f_arr, reference)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment