From 7cf5b5bf57b7f7b41210150608300686cbb08d75 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 16 Aug 2019 13:19:49 +0200
Subject: [PATCH] Improved handling of integer functions and typing of
 constants

- numpy constants get directly their numpy type
- integer functions check for integer types at construction
---
 pystencils/data_types.py            | 11 +++--
 pystencils/integer_functions.py     | 65 +++++++++++++++++++++++++----
 pystencils/transformations.py       |  5 +++
 pystencils_tests/test_address_of.py | 12 ------
 4 files changed, 70 insertions(+), 23 deletions(-)

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 8692ec5..23dcf4f 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -47,6 +47,11 @@ class cast_func(sp.Function):
     is_Atom = True
 
     def __new__(cls, *args, **kwargs):
+        if len(args) != 2:
+            pass
+        expr, dtype, *other_args = args
+        if not isinstance(dtype, Type):
+            dtype = create_type(dtype)
         # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
         # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
         # to problems when for example comparing cast_func's for equality
@@ -55,9 +60,9 @@ class cast_func(sp.Function):
         # rhs = cast_func(0, 'int')
         # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
         # -> thus a separate class boolean_cast_func is introduced
-        if isinstance(args[0], Boolean):
+        if isinstance(expr, Boolean):
             cls = boolean_cast_func
-        return sp.Function.__new__(cls, *args, **kwargs)
+        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
 
     @property
     def canonical(self):
@@ -471,7 +476,7 @@ class BasicType(Type):
         return 1
 
     def is_int(self):
-        return self.numpy_dtype in np.sctypes['int']
+        return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint']
 
     def is_float(self):
         return self.numpy_dtype in np.sctypes['float']
diff --git a/pystencils/integer_functions.py b/pystencils/integer_functions.py
index 4e583d9..b54bbaa 100644
--- a/pystencils/integer_functions.py
+++ b/pystencils/integer_functions.py
@@ -1,15 +1,64 @@
+import numpy as np
 import sympy as sp
 
-from pystencils.data_types import collate_types, get_type_of_expression
+from pystencils.data_types import cast_func, collate_types, create_type, get_type_of_expression
 from pystencils.sympyextensions import is_integer_sequence
 
-bitwise_xor = sp.Function("bitwise_xor")
-bit_shift_right = sp.Function("bit_shift_right")
-bit_shift_left = sp.Function("bit_shift_left")
-bitwise_and = sp.Function("bitwise_and")
-bitwise_or = sp.Function("bitwise_or")
-int_div = sp.Function("int_div")
-int_power_of_2 = sp.Function("int_power_of_2")
+
+class IntegerFunctionTwoArgsMixIn(sp.Function):
+    def __new__(cls, arg1, arg2):
+        args = []
+        for a in (arg1, arg2):
+            if isinstance(a, sp.Number) or isinstance(a, int):
+                args.append(cast_func(a, create_type("int")))
+            elif isinstance(a, np.generic):
+                args.append(cast_func(a, a.dtype))
+            else:
+                args.append(a)
+
+        for a in args:
+            try:
+                type = get_type_of_expression(a)
+                if not type.is_int():
+                    raise ValueError("Argument to integer function is not an int but " + str(type))
+            except NotImplementedError:
+                raise ValueError("Integer functions can only be constructed with typed expressions")
+        return super().__new__(cls, *args)
+
+
+# noinspection PyPep8Naming
+class bitwise_xor(IntegerFunctionTwoArgsMixIn):
+    pass
+
+
+# noinspection PyPep8Naming
+class bit_shift_right(IntegerFunctionTwoArgsMixIn):
+    pass
+
+
+# noinspection PyPep8Naming
+class bit_shift_left(IntegerFunctionTwoArgsMixIn):
+    pass
+
+
+# noinspection PyPep8Naming
+class bitwise_and(IntegerFunctionTwoArgsMixIn):
+    pass
+
+
+# noinspection PyPep8Naming
+class bitwise_or(IntegerFunctionTwoArgsMixIn):
+    pass
+
+
+# noinspection PyPep8Naming
+class int_div(IntegerFunctionTwoArgsMixIn):
+    pass
+
+
+# noinspection PyPep8Naming
+class int_power_of_2(IntegerFunctionTwoArgsMixIn):
+    pass
 
 
 # noinspection PyPep8Naming
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 60a22d8..39a9abd 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -5,6 +5,7 @@ from collections import OrderedDict, defaultdict, namedtuple
 from copy import deepcopy
 from types import MappingProxyType
 
+import numpy as np
 import sympy as sp
 from sympy.logic.boolalg import Boolean
 
@@ -802,6 +803,8 @@ class KernelConstraintsCheck:
             return rhs
         elif isinstance(rhs, sp.Symbol):
             return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
+        elif type_constants and isinstance(rhs, np.generic):
+            return cast_func(rhs, create_type(rhs.dtype))
         elif type_constants and isinstance(rhs, sp.Number):
             return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
         elif isinstance(rhs, sp.Mul):
@@ -809,6 +812,8 @@ class KernelConstraintsCheck:
             return rhs.func(*new_args) if new_args else rhs
         elif isinstance(rhs, sp.Indexed):
             return rhs
+        elif isinstance(rhs, cast_func):
+            return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype)
         else:
             if isinstance(rhs, sp.Pow):
                 # don't process exponents -> they should remain integers
diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py
index 6e23d5f..c31091c 100644
--- a/pystencils_tests/test_address_of.py
+++ b/pystencils_tests/test_address_of.py
@@ -1,10 +1,7 @@
-
 """
 Test of pystencils.data_types.address_of
 """
 
-import sympy
-
 import pystencils
 from pystencils.data_types import PointerType, address_of, cast_func
 from pystencils.simp.simplifications import sympy_cse
@@ -48,12 +45,3 @@ def test_address_of_with_cse():
     ast = pystencils.create_kernel(assignments_cse)
     code = pystencils.show_code(ast)
     print(code)
-
-
-def main():
-    test_address_of()
-    test_address_of_with_cse()
-
-
-if __name__ == '__main__':
-    main()
-- 
GitLab