Skip to content
Snippets Groups Projects
Commit 2f1c1194 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix and improved bit mask support

parent 69286c9b
Branches
Tags
No related merge requests found
import sympy as sp import sympy as sp
from pystencils.typing import get_type_of_expression # from pystencils.typing import get_type_of_expression
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -22,13 +22,14 @@ class flag_cond(sp.Function): ...@@ -22,13 +22,14 @@ class flag_cond(sp.Function):
def __new__(cls, flag_bit, mask_expression, *expressions): def __new__(cls, flag_bit, mask_expression, *expressions):
flag_dtype = get_type_of_expression(flag_bit) # TODO reintroduce checking
if not flag_dtype.is_int(): # flag_dtype = get_type_of_expression(flag_bit)
raise ValueError('Argument flag_bit must be of integer type.') # if not flag_dtype.is_int():
# raise ValueError('Argument flag_bit must be of integer type.')
mask_dtype = get_type_of_expression(mask_expression) #
if not mask_dtype.is_int(): # mask_dtype = get_type_of_expression(mask_expression)
raise ValueError('Argument mask_expression must be of integer type.') # if not mask_dtype.is_int():
# raise ValueError('Argument mask_expression must be of integer type.')
return super().__new__(cls, flag_bit, mask_expression, *expressions) return super().__new__(cls, flag_bit, mask_expression, *expressions)
......
...@@ -175,7 +175,10 @@ class TypeAdder: ...@@ -175,7 +175,10 @@ class TypeAdder:
raise NotImplementedError('integer_functions') raise NotImplementedError('integer_functions')
elif isinstance(expr, flag_cond): elif isinstance(expr, flag_cond):
# do not process the arguments to the bit shift - they must remain integers # do not process the arguments to the bit shift - they must remain integers
raise NotImplementedError('flag_cond') args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
collated_type = collate_types([t for _, t in args_types])
new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
#elif isinstance(expr, sp.Mul): #elif isinstance(expr, sp.Mul):
# raise NotImplementedError('sp.Mul') # raise NotImplementedError('sp.Mul')
# # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? # # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
......
import pytest
import numpy as np import numpy as np
import pystencils as ps
from pystencils import Field, Assignment, create_kernel from pystencils import Field, Assignment, create_kernel
from pystencils.bit_masks import flag_cond from pystencils.bit_masks import flag_cond
def test_flag_condition(): @pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64])
def test_flag_condition(mask_type):
f_arr = np.zeros((2, 2, 2), dtype=np.float64) f_arr = np.zeros((2, 2, 2), dtype=np.float64)
mask_arr = np.zeros((2, 2), dtype=np.uint64) mask_arr = np.zeros((2, 2), dtype=mask_type)
mask_arr[0, 1] = (1 << 3) mask_arr[0, 1] = (1 << 3)
mask_arr[1, 0] = (1 << 5) mask_arr[1, 0] = (1 << 5)
...@@ -16,7 +20,7 @@ def test_flag_condition(): ...@@ -16,7 +20,7 @@ def test_flag_condition():
v1 = 42.3 v1 = 42.3
v2 = 39.7 v2 = 39.7
v3 = 119.87 v3 = 119
assignments = [ assignments = [
Assignment(f(0), flag_cond(3, mask(0), v1)), Assignment(f(0), flag_cond(3, mask(0), v1)),
...@@ -25,6 +29,8 @@ def test_flag_condition(): ...@@ -25,6 +29,8 @@ def test_flag_condition():
kernel = create_kernel(assignments).compile() kernel = create_kernel(assignments).compile()
kernel(f=f_arr, mask=mask_arr) kernel(f=f_arr, mask=mask_arr)
code = ps.get_code_str(kernel)
assert '119.0' in code
reference = np.zeros((2, 2, 2), dtype=np.float64) reference = np.zeros((2, 2, 2), dtype=np.float64)
reference[0, 1, 0] = v1 reference[0, 1, 0] = v1
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment