From 2f1c1194d0967322e0d83f2804677ea51de4d469 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Fri, 21 Jan 2022 14:27:08 +0100 Subject: [PATCH] Fix and improved bit mask support --- pystencils/bit_masks.py | 17 +++++++++-------- pystencils/typing/leaf_typing.py | 5 ++++- pystencils_tests/test_bit_masks.py | 12 +++++++++--- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py index 73c18688c..ad4e967f0 100644 --- a/pystencils/bit_masks.py +++ b/pystencils/bit_masks.py @@ -1,5 +1,5 @@ import sympy as sp -from pystencils.typing import get_type_of_expression +# from pystencils.typing import get_type_of_expression # noinspection PyPep8Naming @@ -22,13 +22,14 @@ class flag_cond(sp.Function): def __new__(cls, flag_bit, mask_expression, *expressions): - flag_dtype = get_type_of_expression(flag_bit) - if not flag_dtype.is_int(): - raise ValueError('Argument flag_bit must be of integer type.') - - mask_dtype = get_type_of_expression(mask_expression) - if not mask_dtype.is_int(): - raise ValueError('Argument mask_expression must be of integer type.') + # TODO reintroduce checking + # flag_dtype = get_type_of_expression(flag_bit) + # if not flag_dtype.is_int(): + # raise ValueError('Argument flag_bit must be of integer type.') + # + # mask_dtype = get_type_of_expression(mask_expression) + # if not mask_dtype.is_int(): + # raise ValueError('Argument mask_expression must be of integer type.') return super().__new__(cls, flag_bit, mask_expression, *expressions) diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index b620c9c7e..ef81b529d 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -175,7 +175,10 @@ class TypeAdder: raise NotImplementedError('integer_functions') elif isinstance(expr, flag_cond): # do not process the arguments to the bit shift - they must remain integers - raise NotImplementedError('flag_cond') + args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))] + collated_type = collate_types([t for _, t in args_types]) + new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] + return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type #elif isinstance(expr, sp.Mul): # raise NotImplementedError('sp.Mul') # # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? diff --git a/pystencils_tests/test_bit_masks.py b/pystencils_tests/test_bit_masks.py index 57371976f..423fc13cc 100644 --- a/pystencils_tests/test_bit_masks.py +++ b/pystencils_tests/test_bit_masks.py @@ -1,11 +1,15 @@ +import pytest import numpy as np + +import pystencils as ps from pystencils import Field, Assignment, create_kernel from pystencils.bit_masks import flag_cond -def test_flag_condition(): +@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_flag_condition(mask_type): f_arr = np.zeros((2, 2, 2), dtype=np.float64) - mask_arr = np.zeros((2, 2), dtype=np.uint64) + mask_arr = np.zeros((2, 2), dtype=mask_type) mask_arr[0, 1] = (1 << 3) mask_arr[1, 0] = (1 << 5) @@ -16,7 +20,7 @@ def test_flag_condition(): v1 = 42.3 v2 = 39.7 - v3 = 119.87 + v3 = 119 assignments = [ Assignment(f(0), flag_cond(3, mask(0), v1)), @@ -25,6 +29,8 @@ def test_flag_condition(): kernel = create_kernel(assignments).compile() kernel(f=f_arr, mask=mask_arr) + code = ps.get_code_str(kernel) + assert '119.0' in code reference = np.zeros((2, 2, 2), dtype=np.float64) reference[0, 1, 0] = v1 -- GitLab