From 3c8507ba963c5e5c0d7006cd573134a40e52f26f Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 10 Mar 2024 16:53:38 +0100
Subject: [PATCH] introduce type atom wrapper. fix BC index dtype. Add CastFunc
 to freeze/typify.

---
 .../backend/kernelcreation/freeze.py          |  6 ++-
 .../backend/kernelcreation/typification.py    |  5 +++
 src/pystencils/boundaries/boundaryhandling.py | 43 +++++++++++--------
 src/pystencils/boundaries/createindexlist.py  |  4 +-
 src/pystencils/sympyextensions/typed_sympy.py | 25 +++++++++--
 tests/test_boundary.py                        |  3 --
 6 files changed, 60 insertions(+), 26 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index e0a5b130f..6861ff4c5 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -5,7 +5,7 @@ from operator import add, mul, sub
 import sympy as sp
 
 from ...sympyextensions import Assignment, AssignmentCollection
-from ...sympyextensions.typed_sympy import TypedSymbol
+from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
 from ...field import Field, FieldType
 
 from .context import KernelCreationContext
@@ -26,6 +26,7 @@ from ..ast.expressions import (
     PsConstantExpr,
     PsArrayInitList,
     PsSubscript,
+    PsCast
 )
 
 from ..constants import PsConstant
@@ -305,3 +306,6 @@ class FreezeExpressions:
 
         args = tuple(self.visit_expr(arg) for arg in func.args)
         return PsCall(func_symbol, args)
+    
+    def map_CastFunc(self, cast_expr: CastFunc):
+        return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr))
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index cc526e895..972a47c8c 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -22,6 +22,7 @@ from ..ast.expressions import (
     PsLookup,
     PsCall,
     PsArrayInitList,
+    PsCast
 )
 from ..functions import PsMathFunction
 
@@ -280,5 +281,9 @@ class Typifier:
                     arr_type = PsArrayType(items_tc.target_type, len(items))
                     tc.apply_and_check(expr, arr_type)
 
+            case PsCast(dtype, operand):
+                self.visit_expr(operand, TypeContext())
+                tc.apply_and_check(expr, dtype)
+
             case _:
                 raise NotImplementedError(f"Can't typify {expr}")
diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index e7b44099d..57a1cd95f 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -8,8 +8,8 @@ from pystencils.sympyextensions import Assignment
 from pystencils.boundaries.createindexlist import (
     create_boundary_index_array, numpy_data_type_for_boundary_object)
 from pystencils.sympyextensions import TypedSymbol
-from pystencils.defaults import DEFAULTS
-from pystencils.types.quick import Arr, create_type
+from pystencils.types import PsIntegerType
+from pystencils.types.quick import Arr, SInt
 from pystencils.gpu.gpu_array_handler import GPUArrayHandler
 from pystencils.field import Field, FieldType
 from pystencils.backend.kernelfunction import FieldPointerParam
@@ -417,37 +417,46 @@ class BoundaryOffsetInfo:
 
     @staticmethod
     def inv_dir(dir_idx):
-        return sp.IndexedBase(BoundaryOffsetInfo.INV_DIR_SYMBOL, shape=(1,))[dir_idx]
+        return sp.IndexedBase(BoundaryOffsetInfo._inv_dir_symbol(), shape=(1,))[dir_idx]
 
     # ---------------------------------- Internal ---------------------------------------------
 
-    @staticmethod
-    def get_array_declarations(stencil) -> list[Assignment]:
-        dim = len(stencil[0])
+    def __init__(self, stencil, index_dtype: PsIntegerType = SInt(32)) -> None:
+        self._stencil = stencil
+        self._dim = len(stencil[0])
+        self._index_dtype = index_dtype
+
+    def get_array_declarations(self) -> list[Assignment]:
         asms = []
-        for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(dim)):
-            offsets = tuple(d[i] for d in stencil)
+        for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(self._dim)):
+            offsets = tuple(d[i] for d in self._stencil)
             asms.append(Assignment(offset_symb, offsets))
 
         inv_dirs = []
-        for direction in stencil:
+        for direction in self._stencil:
             inverse_dir = tuple([-i for i in direction])
-            inv_dirs.append(str(stencil.index(inverse_dir)))
+            inv_dirs.append(str(self._stencil.index(inverse_dir)))
 
-        asms.append(Assignment(BoundaryOffsetInfo.INV_DIR_SYMBOL, tuple(inv_dirs)))
+        asms.append(Assignment(BoundaryOffsetInfo._inv_dir_symbol(), tuple(inv_dirs)))
         return asms
 
     @staticmethod
-    def _offset_symbols(dim):
-        return [TypedSymbol(f"c{d}", Arr(create_type(DEFAULTS.index_dtype))) for d in ['x', 'y', 'z'][:dim]]
+    def _offset_symbols(dim, dtype: PsIntegerType = SInt(32)):
+        return [TypedSymbol(f"c{d}", Arr(dtype)) for d in ['x', 'y', 'z'][:dim]]
 
-    INV_DIR_SYMBOL = TypedSymbol("invdir", Arr(create_type(DEFAULTS.index_dtype)))
+    @staticmethod
+    def _inv_dir_symbol(dtype: PsIntegerType = SInt(32)):
+        return TypedSymbol("invdir", Arr(dtype))
 
 
 def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
-    elements = BoundaryOffsetInfo.get_array_declarations(stencil)
-    dir_symbol = TypedSymbol("dir", DEFAULTS.index_dtype)
+    #   TODO: reconsider how to control the index_dtype in boundary kernels
+    config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args)
+
+    offset_info = BoundaryOffsetInfo(stencil, config.index_dtype)
+    elements = offset_info.get_array_declarations()
+    dir_symbol = TypedSymbol("dir", config.index_dtype)
     elements += [Assignment(dir_symbol, index_field[0]('dir'))]
     elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
-    config = CreateKernelConfig(index_field=index_field, target=target, **kernel_creation_args)
+    
     return create_kernel(elements, config=config)
diff --git a/src/pystencils/boundaries/createindexlist.py b/src/pystencils/boundaries/createindexlist.py
index 34bd0766e..2bc346cc0 100644
--- a/src/pystencils/boundaries/createindexlist.py
+++ b/src/pystencils/boundaries/createindexlist.py
@@ -1,7 +1,7 @@
 import warnings
 
 import numpy as np
-from pystencils.defaults import DEFAULTS
+from pystencils.types.quick import SInt
 
 
 try:
@@ -22,7 +22,7 @@ if cython_funcs_available:
 
 boundary_index_array_coordinate_names = ["x", "y", "z"]
 direction_member_name = "dir"
-default_index_array_dtype = DEFAULTS.index_dtype
+default_index_array_dtype = SInt(32)
 
 
 def numpy_data_type_for_boundary_object(boundary_object, dim):
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index accea6d40..c49c46c39 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -34,6 +34,22 @@ def is_loop_counter_symbol(symbol):
         return None
 
 
+class PsTypeAtom(sp.Atom):
+    """Wrapper around a PsType to disguise it as a SymPy atom."""
+
+    def __new__(cls, *args, **kwargs):
+        return sp.Basic.__new__(cls)
+    
+    def __init__(self, dtype: PsType) -> None:
+        self._dtype = dtype
+
+    def _sympystr(self, *args, **kwargs):
+        return str(self._dtype)
+
+    def get(self) -> PsType:
+        return self._dtype
+
+
 class TypedSymbol(sp.Symbol):
     def __new__(cls, *args, **kwds):
         obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
@@ -192,7 +208,9 @@ class CastFunc(sp.Function):
         # This optimisation is only available for simple casts. Thus the == is intended here!
         if expr.__class__ == CastFunc:
             expr = expr.args[0]
-        dtype = create_type(dtype)
+
+        if not isinstance(dtype, PsTypeAtom):
+            dtype = PsTypeAtom(create_type(dtype))
         # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
         # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
         # to problems when for example comparing cast_func's for equality
@@ -220,8 +238,9 @@ class CastFunc(sp.Function):
         return self.args[0].is_commutative
 
     @property
-    def dtype(self):
-        return self.args[1]
+    def dtype(self) -> PsType:
+        assert isinstance(self.args[1], PsTypeAtom)
+        return self.args[1].get()
 
     @property
     def expr(self):
diff --git a/tests/test_boundary.py b/tests/test_boundary.py
index 84c390221..a94d37820 100644
--- a/tests/test_boundary.py
+++ b/tests/test_boundary.py
@@ -244,6 +244,3 @@ def test_dirichlet(with_indices):
     assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, -1]])
     assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[0, 1:-2]])
     assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[-1, 1:-2]])
-
-
-test_kernel_vs_copy_boundary()
-- 
GitLab