From 0c9e9fcdff09fc8ed4146ca1cb8cd247975b3d48 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 10 Dec 2021 14:17:07 +0100
Subject: [PATCH] Support min max and conditional Field Access

---
 pystencils/backends/cbackend.py               |  2 -
 pystencils/config.py                          |  1 +
 pystencils/typing/leaf_typing.py              | 65 +++++++++++--------
 pystencils_tests/test_Min_Max.py              | 65 +++++++++++++++++--
 .../test_conditional_field_access.py          | 30 ++++-----
 pystencils_tests/test_types.py                |  6 +-
 6 files changed, 113 insertions(+), 56 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 4aa1d0964..5b27f7461 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -501,8 +501,6 @@ class CustomSympyPrinter(CCodePrinter):
             return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
         elif isinstance(expr, sp.Abs):
             return f"abs({self._print(expr.args[0])})"
-        elif isinstance(expr, sp.Max):
-            return self._print(expr)
         elif isinstance(expr, sp.Mod):
             if expr.args[0].is_integer and expr.args[1].is_integer:
                 return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
diff --git a/pystencils/config.py b/pystencils/config.py
index 97f2e2d8a..936a92cf2 100644
--- a/pystencils/config.py
+++ b/pystencils/config.py
@@ -31,6 +31,7 @@ class CreateKernelConfig:
     """
     # TODO: config should check that the datatype is a Numpy type
     # TODO: check for the python types and issue warnings
+    # TODO: QoL default_number_float and default_number_int should be data_type if they are not specified by the user
     data_type: Union[str, Dict[str, BasicType]] = 'float64'
     """
     Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index 1c84ffd65..04bfacbf4 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -1,6 +1,5 @@
 from collections import namedtuple, defaultdict
-from copy import copy
-from typing import Union, Dict, Tuple, Any
+from typing import Union, Tuple, Any
 import logging
 
 import numpy as np
@@ -14,8 +13,9 @@ from sympy.logic.boolalg import BooleanFunction
 from sympy.logic.boolalg import BooleanAtom
 
 from pystencils import astnodes as ast
+from pystencils.functions import DivFunc
 from pystencils.field import Field
-from pystencils.typing.types import AbstractType, BasicType, create_type
+from pystencils.typing.types import BasicType, create_type
 from pystencils.typing.utilities import get_type_of_expression, collate_types
 from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc
 from pystencils.typing.typed_sympy import TypedSymbol
@@ -40,9 +40,9 @@ class TypeAdder:
     """
     FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
 
-    def __init__(self, type_for_symbol: Dict[str, BasicType], default_number_float: BasicType,
+    def __init__(self, type_for_symbol: defaultdict[str, BasicType], default_number_float: BasicType,
                  default_number_int: BasicType):
-        self.type_for_symbol = ContextVar(type_for_symbol)
+        self.type_for_symbol = type_for_symbol
         self.default_number_float = ContextVar(default_number_float)
         self.default_number_int = ContextVar(default_number_int)
 
@@ -72,13 +72,16 @@ class TypeAdder:
     def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment:
         # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
         new_rhs, rhs_type = self.figure_out_type(assignment.rhs)
-        # TODO:
-        dt = copy(rhs_type)  # The copy is necessary because BasicType has sympy shinanigans
-        dd = defaultdict(lambda: BasicType(dt))
-        dd.update(self.type_for_symbol.get())
-        with self.type_for_symbol(dd):
-            new_lhs, lhs_type = self.figure_out_type(assignment.lhs)
-        # TODO add symbol to dict with type if defined!
+
+        lhs = assignment.lhs
+        if not isinstance(lhs, (Field.Access, TypedSymbol)):
+            if isinstance(lhs, sp.Symbol):
+                self.type_for_symbol[lhs.name] = rhs_type
+            else:
+                raise ValueError(f'Lhs: `{lhs}` is not a subtype of sp.Symbol')
+        new_lhs, lhs_type = self.figure_out_type(lhs)
+        assert isinstance(new_lhs, (Field.Access, TypedSymbol))
+
         if lhs_type != rhs_type:
             logging.warning(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
                             f'rhs: "{new_rhs}" of type "{rhs_type}".')
@@ -89,7 +92,8 @@ class TypeAdder:
     # Type System Specification
     # - Defined Types: TypedSymbol, Field, Field.Access, ...?
     # - Indexed: always unsigned_integer64
-    # - Undefined Types: Symbol - Is specified in Config in the dict or as 'default_type'
+    # - Undefined Types: Symbol
+    #       - Is specified in Config in the dict or as 'default_type' or behaves like `auto` in the case of lhs.
     # - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config
     #       - Example: 1.4 config:float32 -> float32
     # - Expressions deduce types from arguments
@@ -100,7 +104,7 @@ class TypeAdder:
     # Possible Problems - Do we need to support this?
     # - Mixture in expression with int and float
     # - Mixture in expression with uint64 and sint64
-
+    # TODO: Lowest log level should log all casts ----> cast factory, make cast should contain logging
     def figure_out_type(self, expr) -> Tuple[Any, BasicType]:  # TODO or abstract type? vector type?
         # Trivial cases
         from pystencils.field import Field
@@ -113,7 +117,7 @@ class TypeAdder:
         elif isinstance(expr, TypedSymbol):
             return expr, expr.dtype
         elif isinstance(expr, sp.Symbol):
-            t = TypedSymbol(expr.name, self.type_for_symbol.get()[expr.name])  # TODO with or without name
+            t = TypedSymbol(expr.name, self.type_for_symbol[expr.name])  # TODO with or without name
             return t, t.dtype
         elif isinstance(expr, np.generic):
             assert False, f'Why do we have a np.generic in rhs???? {expr}'
@@ -139,6 +143,22 @@ class TypeAdder:
         elif isinstance(expr, CastFunc):
             new_expr, _ = self.figure_out_type(expr.expr)
             return expr.func(*[new_expr, expr.dtype]), expr.dtype
+        elif isinstance(expr, ast.ConditionalFieldAccess):
+            access, access_type = self.figure_out_type(expr.access)
+            value, value_type = self.figure_out_type(expr.outofbounds_value)
+            condition, condition_type = self.figure_out_type(expr.outofbounds_condition)
+            assert condition_type == bool_type
+            collated_type = collate_types([access_type, value_type])
+            if collated_type == access_type:
+                new_access = access
+            else:
+                logging.warning(f"In {expr} the Field Access had to be casted to {collated_type}. This is "
+                                f"probably due to a type missmatch of the Field and the value of "
+                                f"ConditionalFieldAccess")
+                new_access = CastFunc(access, collated_type)
+
+            new_value = value if value_type == collated_type else CastFunc(value, collated_type)
+            return expr.func(new_access, condition, new_value), collated_type
         elif isinstance(expr, BooleanFunction):
             args_types = [self.figure_out_type(a) for a in expr.args]
             new_args = [a if t.dtype_eq(bool_type) else BooleanCastFunc(a, bool_type) for a, t in args_types]
@@ -177,16 +197,15 @@ class TypeAdder:
                 else:
                     new_args.append(a)
             return expr.func(*new_args) if new_args else expr, collated_type
-        else:
+        elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc)):
             args_types = [self.figure_out_type(arg) for arg in expr.args]
             collated_type = collate_types([t for _, t in args_types])
             new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
             return expr.func(*new_args) if new_args else expr, collated_type
+        else:
+            raise NotImplementedError(f'expr {expr} unknown to typing')
 
-    def apply_type(self, expr, data_type: AbstractType):
-        pass
-
-    def process_expression(self, rhs, type_constants=True):  # TODO default_type as parameter
+    def process_expression(self, rhs, type_constants=True):  # TODO DELETE
         import pystencils.integer_functions
         from pystencils.bit_masks import flag_cond
 
@@ -242,9 +261,3 @@ class TypeAdder:
         else:
             new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
             return rhs.func(*new_args) if new_args else rhs
-
-    def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]):
-        if not isinstance(lhs, (Field.Access, TypedSymbol)):
-            return TypedSymbol(lhs.name, self.type_for_symbol.get()[lhs.name])
-        else:
-            return lhs
diff --git a/pystencils_tests/test_Min_Max.py b/pystencils_tests/test_Min_Max.py
index c227fbf14..7fb48b18d 100644
--- a/pystencils_tests/test_Min_Max.py
+++ b/pystencils_tests/test_Min_Max.py
@@ -6,31 +6,37 @@ import pystencils
 from pystencils.datahandling import create_data_handling
 
 
+@pytest.mark.parametrize('dtype', ["float64", "float32"])
 @pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
-def test_max(sympy_function):
+def test_max(dtype, sympy_function):
     dh = create_data_handling(domain_size=(10, 10), periodicity=True)
 
-    x = dh.add_array('x', values_per_cell=1)
+    x = dh.add_array('x', values_per_cell=1, dtype=dtype)
     dh.fill("x", 0.0, ghost_layers=True)
-    y = dh.add_array('y', values_per_cell=1)
+    y = dh.add_array('y', values_per_cell=1, dtype=dtype)
     dh.fill("y", 1.0, ghost_layers=True)
-    z = dh.add_array('z', values_per_cell=1)
+    z = dh.add_array('z', values_per_cell=1, dtype=dtype)
     dh.fill("z", 2.0, ghost_layers=True)
 
+    config = pystencils.CreateKernelConfig(default_number_float=dtype)
+
     # test sp.Max with one argument
     assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3))
-    ast_1 = pystencils.create_kernel(assignment_1)
+    ast_1 = pystencils.create_kernel(assignment_1, config=config)
     kernel_1 = ast_1.compile()
+    # pystencils.show_code(ast_1)
 
     # test sp.Max with two arguments
     assignment_2 = pystencils.Assignment(x.center, sympy_function(0.5, y.center - 1.5))
-    ast_2 = pystencils.create_kernel(assignment_2)
+    ast_2 = pystencils.create_kernel(assignment_2, config=config)
     kernel_2 = ast_2.compile()
+    # pystencils.show_code(ast_2)
 
     # test sp.Max with many arguments
     assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4.5, y.center - 1.5, y.center + z.center))
-    ast_3 = pystencils.create_kernel(assignment_3)
+    ast_3 = pystencils.create_kernel(assignment_3, config=config)
     kernel_3 = ast_3.compile()
+    # pystencils.show_code(ast_3)
 
     if sympy_function is sp.Max:
         results = [4.3, 0.5, 4.5]
@@ -43,3 +49,48 @@ def test_max(sympy_function):
     assert numpy.all(dh.gather_array('x') == results[1])
     dh.run_kernel(kernel_3)
     assert numpy.all(dh.gather_array('x') == results[2])
+
+
+@pytest.mark.parametrize('dtype', ["int64", 'int32'])
+@pytest.mark.parametrize('sympy_function', [sp.Min, sp.Max])
+def test_max_integer(dtype, sympy_function):
+    dh = create_data_handling(domain_size=(10, 10), periodicity=True)
+
+    x = dh.add_array('x', values_per_cell=1, dtype=dtype)
+    dh.fill("x", 0, ghost_layers=True)
+    y = dh.add_array('y', values_per_cell=1, dtype=dtype)
+    dh.fill("y", 1, ghost_layers=True)
+    z = dh.add_array('z', values_per_cell=1, dtype=dtype)
+    dh.fill("z", 2, ghost_layers=True)
+
+    config = pystencils.CreateKernelConfig(default_number_int=dtype)
+
+    # test sp.Max with one argument
+    assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3))
+    ast_1 = pystencils.create_kernel(assignment_1, config=config)
+    kernel_1 = ast_1.compile()
+    # pystencils.show_code(ast_1)
+
+    # test sp.Max with two arguments
+    assignment_2 = pystencils.Assignment(x.center, sympy_function(1, y.center - 1))
+    ast_2 = pystencils.create_kernel(assignment_2, config=config)
+    kernel_2 = ast_2.compile()
+    # pystencils.show_code(ast_2)
+
+    # test sp.Max with many arguments
+    assignment_3 = pystencils.Assignment(x.center, sympy_function(z.center, 4, y.center - 1, y.center + z.center))
+    ast_3 = pystencils.create_kernel(assignment_3, config=config)
+    kernel_3 = ast_3.compile()
+    # pystencils.show_code(ast_3)
+
+    if sympy_function is sp.Max:
+        results = [4, 1, 4]
+    else:
+        results = [4, 0, 0]
+
+    dh.run_kernel(kernel_1)
+    assert numpy.all(dh.gather_array('x') == results[0])
+    dh.run_kernel(kernel_2)
+    assert numpy.all(dh.gather_array('x') == results[1])
+    dh.run_kernel(kernel_3)
+    assert numpy.all(dh.gather_array('x') == results[2])
diff --git a/pystencils_tests/test_conditional_field_access.py b/pystencils_tests/test_conditional_field_access.py
index f39d4767e..f8026c7dc 100644
--- a/pystencils_tests/test_conditional_field_access.py
+++ b/pystencils_tests/test_conditional_field_access.py
@@ -35,11 +35,11 @@ def add_fixed_constant_boundary_handling(assignments, with_cse):
             for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
         })) for assignment in assignments.all_assignments]
 
-    subs = [{a: ConditionalFieldAccess(a, is_out_of_bound(
-        sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
-        for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
-    } for assignment in assignments.all_assignments]
-    print(subs)
+    # subs = [{a: ConditionalFieldAccess(a, is_out_of_bound(
+    #     sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
+    #     for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
+    # } for assignment in assignments.all_assignments]
+    # print(subs)
 
     if with_cse:
         safe_assignments = sympy_cse(ps.AssignmentCollection(safe_assignments))
@@ -48,24 +48,20 @@ def add_fixed_constant_boundary_handling(assignments, with_cse):
         return ps.AssignmentCollection(safe_assignments)
 
 
+@pytest.mark.parametrize('dtype', ('float64', 'float32'))
 @pytest.mark.parametrize('with_cse', (False, 'with_cse'))
-def test_boundary_check(with_cse):
-    if not with_cse:
-        return True
+def test_boundary_check(dtype, with_cse):
+    f, g = ps.fields(f"f, g : {dtype}[2D]")
+    stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
 
-    f, g = ps.fields("f, g : [2D]")
-    stencil = ps.Assignment(g[0, 0],
-                            (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
-
-    f_arr = np.random.rand(10, 10)
+    f_arr = np.random.rand(10, 10).astype(dtype=dtype)
     g_arr = np.zeros_like(f_arr)
-    # kernel(f=f_arr, g=g_arr)
 
     assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse)
 
-    print(assignments)
-    kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile()
-    ps.show_code(kernel_checked)
+    config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, ghost_layers=0)
+    kernel_checked = ps.create_kernel(assignments, config=config).compile()
+    # ps.show_code(kernel_checked)
 
     # No SEGFAULT, please!!
     kernel_checked(f=f_arr, g=g_arr)
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index c55816b05..164d941cf 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -159,7 +159,8 @@ def test_sqrt_of_integer(dtype):
     assignments = [ps.Assignment(tmp, sp.sqrt(3)),
                    ps.Assignment(f[0], tmp)]
     arr = np.array([1], dtype=dtype)
-    config = pystencils.config.CreateKernelConfig(data_type=dtype)
+    # TODO Jupyter add auto lhs float/double problem
+    config = pystencils.config.CreateKernelConfig(data_type=dtype, default_number_float=dtype)
 
     ast = ps.create_kernel(assignments, config=config)
     kernel = ast.compile()
@@ -189,9 +190,6 @@ def test_integer_comparision(dtype):
         t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));"
     else:
         t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0f): (_data_f_00[_stride_f_1*ctr_1]));"
-
-    print(code)
-
     assert t in code
 
 
-- 
GitLab