From f00895417ed61b5e5d4c626bf608816854bb5a27 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 10 Mar 2024 14:19:40 +0100
Subject: [PATCH] adapt BoundaryOffsetInfo and create_boundary_kernel

---
 src/pystencils/boundaries/boundaryhandling.py | 32 +++++++++----------
 src/pystencils/boundaries/createindexlist.py  |  7 ++--
 2 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index a61f062be..cfccdd078 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -1,3 +1,4 @@
+from typing import Sequence
 from functools import lru_cache
 
 import numpy as np
@@ -8,7 +9,8 @@ from pystencils.sympyextensions import Assignment
 from pystencils.boundaries.createindexlist import (
     create_boundary_index_array, numpy_data_type_for_boundary_object)
 from pystencils.sympyextensions import TypedSymbol
-from pystencils.types import create_type
+from pystencils.defaults import DEFAULTS
+from pystencils.types.quick import Arr, create_type
 from pystencils.gpu.gpu_array_handler import GPUArrayHandler
 from pystencils.field import Field
 from pystencils.backend.kernelfunction import FieldPointerParam
@@ -404,7 +406,6 @@ class BoundaryDataSetter:
         return self.index_array[item]
 
 
-# class BoundaryOffsetInfo(CustomCodeNode): #   TODO nbackend: replace
 class BoundaryOffsetInfo:
 
     # --------------------------- Functions to be used by boundaries --------------------------
@@ -420,35 +421,32 @@ class BoundaryOffsetInfo:
 
     # ---------------------------------- Internal ---------------------------------------------
 
-    def __init__(self, stencil):
+    @staticmethod
+    def get_array_declarations(stencil) -> list[Assignment]:
         dim = len(stencil[0])
-
-        offset_sym = BoundaryOffsetInfo._offset_symbols(dim)
-        code = "\n"
-        for i in range(dim):
-            offset_str = ", ".join([str(d[i]) for d in stencil])
-            code += "const int32_t %s [] = { %s };\n" % (offset_sym[i].name, offset_str)
+        asms = []
+        for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(dim)):
+            offsets = tuple(d[i] for d in stencil)
+            asms.append(Assignment(offset_symb, offsets))
 
         inv_dirs = []
         for direction in stencil:
             inverse_dir = tuple([-i for i in direction])
             inv_dirs.append(str(stencil.index(inverse_dir)))
 
-        code += "const int32_t %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs))
-        offset_symbols = BoundaryOffsetInfo._offset_symbols(dim)
-        super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(),
-                                                 symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL]))
+        asms.append(Assignment(BoundaryOffsetInfo.INV_DIR_SYMBOL, tuple(inv_dirs)))
+        return asms
 
     @staticmethod
     def _offset_symbols(dim):
-        return [TypedSymbol(f"c{d}", create_type(np.int32)) for d in ['x', 'y', 'z'][:dim]]
+        return [TypedSymbol(f"c{d}", Arr(create_type(DEFAULTS.index_dtype))) for d in ['x', 'y', 'z'][:dim]]
 
-    INV_DIR_SYMBOL = TypedSymbol("invdir", np.int32)
+    INV_DIR_SYMBOL = TypedSymbol("invdir", Arr(create_type(DEFAULTS.index_dtype)))
 
 
 def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
-    elements = [BoundaryOffsetInfo(stencil)]
-    dir_symbol = TypedSymbol("dir", np.int32)
+    elements = BoundaryOffsetInfo.get_array_declarations(stencil)
+    dir_symbol = TypedSymbol("dir", DEFAULTS.index_dtype)
     elements += [Assignment(dir_symbol, index_field[0]('dir'))]
     elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
     config = CreateKernelConfig(index_field=index_field, target=target, **kernel_creation_args)
diff --git a/src/pystencils/boundaries/createindexlist.py b/src/pystencils/boundaries/createindexlist.py
index 462d3f329..d20e946fb 100644
--- a/src/pystencils/boundaries/createindexlist.py
+++ b/src/pystencils/boundaries/createindexlist.py
@@ -1,6 +1,7 @@
 import warnings
 
 import numpy as np
+from pystencils.defaults import DEFAULTS
 
 
 try:
@@ -21,14 +22,14 @@ if cython_funcs_available:
 
 boundary_index_array_coordinate_names = ["x", "y", "z"]
 direction_member_name = "dir"
-default_index_array_dtype = np.int32
+default_index_array_dtype = DEFAULTS.index_dtype
 
 
 def numpy_data_type_for_boundary_object(boundary_object, dim):
     coordinate_names = boundary_index_array_coordinate_names[:dim]
     return np.dtype(
-        [(name, default_index_array_dtype) for name in coordinate_names]
-        + [(direction_member_name, default_index_array_dtype)]
+        [(name, default_index_array_dtype.numpy_dtype) for name in coordinate_names]
+        + [(direction_member_name, default_index_array_dtype.numpy_dtype)]
         + [(i[0], i[1].numpy_dtype) for i in boundary_object.additional_data],
         align=True,
     )
-- 
GitLab