diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index a61f062be307da3fffa0f8983fd403c5a39ffc89..cfccdd0783b3ab678129cb44e7f2150ee5331b16 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -1,3 +1,4 @@ +from typing import Sequence from functools import lru_cache import numpy as np @@ -8,7 +9,8 @@ from pystencils.sympyextensions import Assignment from pystencils.boundaries.createindexlist import ( create_boundary_index_array, numpy_data_type_for_boundary_object) from pystencils.sympyextensions import TypedSymbol -from pystencils.types import create_type +from pystencils.defaults import DEFAULTS +from pystencils.types.quick import Arr, create_type from pystencils.gpu.gpu_array_handler import GPUArrayHandler from pystencils.field import Field from pystencils.backend.kernelfunction import FieldPointerParam @@ -404,7 +406,6 @@ class BoundaryDataSetter: return self.index_array[item] -# class BoundaryOffsetInfo(CustomCodeNode): # TODO nbackend: replace class BoundaryOffsetInfo: # --------------------------- Functions to be used by boundaries -------------------------- @@ -420,35 +421,32 @@ class BoundaryOffsetInfo: # ---------------------------------- Internal --------------------------------------------- - def __init__(self, stencil): + @staticmethod + def get_array_declarations(stencil) -> list[Assignment]: dim = len(stencil[0]) - - offset_sym = BoundaryOffsetInfo._offset_symbols(dim) - code = "\n" - for i in range(dim): - offset_str = ", ".join([str(d[i]) for d in stencil]) - code += "const int32_t %s [] = { %s };\n" % (offset_sym[i].name, offset_str) + asms = [] + for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(dim)): + offsets = tuple(d[i] for d in stencil) + asms.append(Assignment(offset_symb, offsets)) inv_dirs = [] for direction in stencil: inverse_dir = tuple([-i for i in direction]) inv_dirs.append(str(stencil.index(inverse_dir))) - code += "const int32_t %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs)) - offset_symbols = BoundaryOffsetInfo._offset_symbols(dim) - super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(), - symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL])) + asms.append(Assignment(BoundaryOffsetInfo.INV_DIR_SYMBOL, tuple(inv_dirs))) + return asms @staticmethod def _offset_symbols(dim): - return [TypedSymbol(f"c{d}", create_type(np.int32)) for d in ['x', 'y', 'z'][:dim]] + return [TypedSymbol(f"c{d}", Arr(create_type(DEFAULTS.index_dtype))) for d in ['x', 'y', 'z'][:dim]] - INV_DIR_SYMBOL = TypedSymbol("invdir", np.int32) + INV_DIR_SYMBOL = TypedSymbol("invdir", Arr(create_type(DEFAULTS.index_dtype))) def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args): - elements = [BoundaryOffsetInfo(stencil)] - dir_symbol = TypedSymbol("dir", np.int32) + elements = BoundaryOffsetInfo.get_array_declarations(stencil) + dir_symbol = TypedSymbol("dir", DEFAULTS.index_dtype) elements += [Assignment(dir_symbol, index_field[0]('dir'))] elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field) config = CreateKernelConfig(index_field=index_field, target=target, **kernel_creation_args) diff --git a/src/pystencils/boundaries/createindexlist.py b/src/pystencils/boundaries/createindexlist.py index 462d3f329263aa1d84e4e7b76897a4cf16b858f7..d20e946fbc82d3cdceb81b3bdb084046d1af731b 100644 --- a/src/pystencils/boundaries/createindexlist.py +++ b/src/pystencils/boundaries/createindexlist.py @@ -1,6 +1,7 @@ import warnings import numpy as np +from pystencils.defaults import DEFAULTS try: @@ -21,14 +22,14 @@ if cython_funcs_available: boundary_index_array_coordinate_names = ["x", "y", "z"] direction_member_name = "dir" -default_index_array_dtype = np.int32 +default_index_array_dtype = DEFAULTS.index_dtype def numpy_data_type_for_boundary_object(boundary_object, dim): coordinate_names = boundary_index_array_coordinate_names[:dim] return np.dtype( - [(name, default_index_array_dtype) for name in coordinate_names] - + [(direction_member_name, default_index_array_dtype)] + [(name, default_index_array_dtype.numpy_dtype) for name in coordinate_names] + + [(direction_member_name, default_index_array_dtype.numpy_dtype)] + [(i[0], i[1].numpy_dtype) for i in boundary_object.additional_data], align=True, )