diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index e0a5b130fe28827cda003fdbf01072a88075be9a..6861ff4c588c0c8f5c92f963db393dc24c27f9b5 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -5,7 +5,7 @@ from operator import add, mul, sub import sympy as sp from ...sympyextensions import Assignment, AssignmentCollection -from ...sympyextensions.typed_sympy import TypedSymbol +from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc from ...field import Field, FieldType from .context import KernelCreationContext @@ -26,6 +26,7 @@ from ..ast.expressions import ( PsConstantExpr, PsArrayInitList, PsSubscript, + PsCast ) from ..constants import PsConstant @@ -305,3 +306,6 @@ class FreezeExpressions: args = tuple(self.visit_expr(arg) for arg in func.args) return PsCall(func_symbol, args) + + def map_CastFunc(self, cast_expr: CastFunc): + return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index cc526e89539bb4b5db35589ea06c5150bcd2b92a..972a47c8c14775bc30a267fbcd8e94d0a2b89dee 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,6 +22,7 @@ from ..ast.expressions import ( PsLookup, PsCall, PsArrayInitList, + PsCast ) from ..functions import PsMathFunction @@ -280,5 +281,9 @@ class Typifier: arr_type = PsArrayType(items_tc.target_type, len(items)) tc.apply_and_check(expr, arr_type) + case PsCast(dtype, operand): + self.visit_expr(operand, TypeContext()) + tc.apply_and_check(expr, dtype) + case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index e7b44099dc3c38283f37bbf1f63478a2be6361e9..57a1cd95f4cb2e166359e60684a2fe9687e9d320 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -8,8 +8,8 @@ from pystencils.sympyextensions import Assignment from pystencils.boundaries.createindexlist import ( create_boundary_index_array, numpy_data_type_for_boundary_object) from pystencils.sympyextensions import TypedSymbol -from pystencils.defaults import DEFAULTS -from pystencils.types.quick import Arr, create_type +from pystencils.types import PsIntegerType +from pystencils.types.quick import Arr, SInt from pystencils.gpu.gpu_array_handler import GPUArrayHandler from pystencils.field import Field, FieldType from pystencils.backend.kernelfunction import FieldPointerParam @@ -417,37 +417,46 @@ class BoundaryOffsetInfo: @staticmethod def inv_dir(dir_idx): - return sp.IndexedBase(BoundaryOffsetInfo.INV_DIR_SYMBOL, shape=(1,))[dir_idx] + return sp.IndexedBase(BoundaryOffsetInfo._inv_dir_symbol(), shape=(1,))[dir_idx] # ---------------------------------- Internal --------------------------------------------- - @staticmethod - def get_array_declarations(stencil) -> list[Assignment]: - dim = len(stencil[0]) + def __init__(self, stencil, index_dtype: PsIntegerType = SInt(32)) -> None: + self._stencil = stencil + self._dim = len(stencil[0]) + self._index_dtype = index_dtype + + def get_array_declarations(self) -> list[Assignment]: asms = [] - for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(dim)): - offsets = tuple(d[i] for d in stencil) + for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(self._dim)): + offsets = tuple(d[i] for d in self._stencil) asms.append(Assignment(offset_symb, offsets)) inv_dirs = [] - for direction in stencil: + for direction in self._stencil: inverse_dir = tuple([-i for i in direction]) - inv_dirs.append(str(stencil.index(inverse_dir))) + inv_dirs.append(str(self._stencil.index(inverse_dir))) - asms.append(Assignment(BoundaryOffsetInfo.INV_DIR_SYMBOL, tuple(inv_dirs))) + asms.append(Assignment(BoundaryOffsetInfo._inv_dir_symbol(), tuple(inv_dirs))) return asms @staticmethod - def _offset_symbols(dim): - return [TypedSymbol(f"c{d}", Arr(create_type(DEFAULTS.index_dtype))) for d in ['x', 'y', 'z'][:dim]] + def _offset_symbols(dim, dtype: PsIntegerType = SInt(32)): + return [TypedSymbol(f"c{d}", Arr(dtype)) for d in ['x', 'y', 'z'][:dim]] - INV_DIR_SYMBOL = TypedSymbol("invdir", Arr(create_type(DEFAULTS.index_dtype))) + @staticmethod + def _inv_dir_symbol(dtype: PsIntegerType = SInt(32)): + return TypedSymbol("invdir", Arr(dtype)) def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args): - elements = BoundaryOffsetInfo.get_array_declarations(stencil) - dir_symbol = TypedSymbol("dir", DEFAULTS.index_dtype) + # TODO: reconsider how to control the index_dtype in boundary kernels + config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args) + + offset_info = BoundaryOffsetInfo(stencil, config.index_dtype) + elements = offset_info.get_array_declarations() + dir_symbol = TypedSymbol("dir", config.index_dtype) elements += [Assignment(dir_symbol, index_field[0]('dir'))] elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field) - config = CreateKernelConfig(index_field=index_field, target=target, **kernel_creation_args) + return create_kernel(elements, config=config) diff --git a/src/pystencils/boundaries/createindexlist.py b/src/pystencils/boundaries/createindexlist.py index 34bd0766ec904d76d15915aac3e08400f7737683..2bc346cc00f301d0228e16e2263a82ddfd8b9799 100644 --- a/src/pystencils/boundaries/createindexlist.py +++ b/src/pystencils/boundaries/createindexlist.py @@ -1,7 +1,7 @@ import warnings import numpy as np -from pystencils.defaults import DEFAULTS +from pystencils.types.quick import SInt try: @@ -22,7 +22,7 @@ if cython_funcs_available: boundary_index_array_coordinate_names = ["x", "y", "z"] direction_member_name = "dir" -default_index_array_dtype = DEFAULTS.index_dtype +default_index_array_dtype = SInt(32) def numpy_data_type_for_boundary_object(boundary_object, dim): diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index accea6d40b596326edecfd54938876f67df88991..c49c46c399bd991fa58b799de905cad1b3987b99 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -34,6 +34,22 @@ def is_loop_counter_symbol(symbol): return None +class PsTypeAtom(sp.Atom): + """Wrapper around a PsType to disguise it as a SymPy atom.""" + + def __new__(cls, *args, **kwargs): + return sp.Basic.__new__(cls) + + def __init__(self, dtype: PsType) -> None: + self._dtype = dtype + + def _sympystr(self, *args, **kwargs): + return str(self._dtype) + + def get(self) -> PsType: + return self._dtype + + class TypedSymbol(sp.Symbol): def __new__(cls, *args, **kwds): obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) @@ -192,7 +208,9 @@ class CastFunc(sp.Function): # This optimisation is only available for simple casts. Thus the == is intended here! if expr.__class__ == CastFunc: expr = expr.args[0] - dtype = create_type(dtype) + + if not isinstance(dtype, PsTypeAtom): + dtype = PsTypeAtom(create_type(dtype)) # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality @@ -220,8 +238,9 @@ class CastFunc(sp.Function): return self.args[0].is_commutative @property - def dtype(self): - return self.args[1] + def dtype(self) -> PsType: + assert isinstance(self.args[1], PsTypeAtom) + return self.args[1].get() @property def expr(self): diff --git a/tests/test_boundary.py b/tests/test_boundary.py index 84c390221649bb96e8deeb567e3a66c67d8c3b13..a94d3782020cd494a4b01009f04016e889fca9c0 100644 --- a/tests/test_boundary.py +++ b/tests/test_boundary.py @@ -244,6 +244,3 @@ def test_dirichlet(with_indices): assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, -1]]) assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[0, 1:-2]]) assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[-1, 1:-2]]) - - -test_kernel_vs_copy_boundary()