diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 8692ec51d0312df8b28399171835f2a563aa7ef9..23dcf4f1260ac75bf6c01afd037fd3cdda51f545 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -47,6 +47,11 @@ class cast_func(sp.Function): is_Atom = True def __new__(cls, *args, **kwargs): + if len(args) != 2: + pass + expr, dtype, *other_args = args + if not isinstance(dtype, Type): + dtype = create_type(dtype) # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality @@ -55,9 +60,9 @@ class cast_func(sp.Function): # rhs = cast_func(0, 'int') # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans # -> thus a separate class boolean_cast_func is introduced - if isinstance(args[0], Boolean): + if isinstance(expr, Boolean): cls = boolean_cast_func - return sp.Function.__new__(cls, *args, **kwargs) + return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) @property def canonical(self): @@ -471,7 +476,7 @@ class BasicType(Type): return 1 def is_int(self): - return self.numpy_dtype in np.sctypes['int'] + return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint'] def is_float(self): return self.numpy_dtype in np.sctypes['float'] diff --git a/pystencils/integer_functions.py b/pystencils/integer_functions.py index 4e583d9ab53469deef93c665c26876ae08364f4a..b54bbaab216676b62244c6ab907e36af2f762959 100644 --- a/pystencils/integer_functions.py +++ b/pystencils/integer_functions.py @@ -1,15 +1,64 @@ +import numpy as np import sympy as sp -from pystencils.data_types import collate_types, get_type_of_expression +from pystencils.data_types import cast_func, collate_types, create_type, get_type_of_expression from pystencils.sympyextensions import is_integer_sequence -bitwise_xor = sp.Function("bitwise_xor") -bit_shift_right = sp.Function("bit_shift_right") -bit_shift_left = sp.Function("bit_shift_left") -bitwise_and = sp.Function("bitwise_and") -bitwise_or = sp.Function("bitwise_or") -int_div = sp.Function("int_div") -int_power_of_2 = sp.Function("int_power_of_2") + +class IntegerFunctionTwoArgsMixIn(sp.Function): + def __new__(cls, arg1, arg2): + args = [] + for a in (arg1, arg2): + if isinstance(a, sp.Number) or isinstance(a, int): + args.append(cast_func(a, create_type("int"))) + elif isinstance(a, np.generic): + args.append(cast_func(a, a.dtype)) + else: + args.append(a) + + for a in args: + try: + type = get_type_of_expression(a) + if not type.is_int(): + raise ValueError("Argument to integer function is not an int but " + str(type)) + except NotImplementedError: + raise ValueError("Integer functions can only be constructed with typed expressions") + return super().__new__(cls, *args) + + +# noinspection PyPep8Naming +class bitwise_xor(IntegerFunctionTwoArgsMixIn): + pass + + +# noinspection PyPep8Naming +class bit_shift_right(IntegerFunctionTwoArgsMixIn): + pass + + +# noinspection PyPep8Naming +class bit_shift_left(IntegerFunctionTwoArgsMixIn): + pass + + +# noinspection PyPep8Naming +class bitwise_and(IntegerFunctionTwoArgsMixIn): + pass + + +# noinspection PyPep8Naming +class bitwise_or(IntegerFunctionTwoArgsMixIn): + pass + + +# noinspection PyPep8Naming +class int_div(IntegerFunctionTwoArgsMixIn): + pass + + +# noinspection PyPep8Naming +class int_power_of_2(IntegerFunctionTwoArgsMixIn): + pass # noinspection PyPep8Naming diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 60a22d812578db85d7f375a74144352e431ffc62..39a9abddbad8dab1ae6c4e3361a43ff1885c4ee8 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -5,6 +5,7 @@ from collections import OrderedDict, defaultdict, namedtuple from copy import deepcopy from types import MappingProxyType +import numpy as np import sympy as sp from sympy.logic.boolalg import Boolean @@ -802,6 +803,8 @@ class KernelConstraintsCheck: return rhs elif isinstance(rhs, sp.Symbol): return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) + elif type_constants and isinstance(rhs, np.generic): + return cast_func(rhs, create_type(rhs.dtype)) elif type_constants and isinstance(rhs, sp.Number): return cast_func(rhs, create_type(self._type_for_symbol['_constant'])) elif isinstance(rhs, sp.Mul): @@ -809,6 +812,8 @@ class KernelConstraintsCheck: return rhs.func(*new_args) if new_args else rhs elif isinstance(rhs, sp.Indexed): return rhs + elif isinstance(rhs, cast_func): + return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype) else: if isinstance(rhs, sp.Pow): # don't process exponents -> they should remain integers diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py index 6e23d5ffd6513cba16acb2d711a9de0cf16c222e..c31091cee2819a1a9a8fc7db851c51a9afec5be9 100644 --- a/pystencils_tests/test_address_of.py +++ b/pystencils_tests/test_address_of.py @@ -1,10 +1,7 @@ - """ Test of pystencils.data_types.address_of """ -import sympy - import pystencils from pystencils.data_types import PointerType, address_of, cast_func from pystencils.simp.simplifications import sympy_cse @@ -48,12 +45,3 @@ def test_address_of_with_cse(): ast = pystencils.create_kernel(assignments_cse) code = pystencils.show_code(ast) print(code) - - -def main(): - test_address_of() - test_address_of_with_cse() - - -if __name__ == '__main__': - main()