Skip to content
Snippets Groups Projects
Commit 3c8507ba authored by Frederik Hennig's avatar Frederik Hennig
Browse files

introduce type atom wrapper. fix BC index dtype. Add CastFunc to freeze/typify.

parent e9bb252f
No related merge requests found
Pipeline #64059 failed with stages
in 2 minutes and 2 seconds
......@@ -5,7 +5,7 @@ from operator import add, mul, sub
import sympy as sp
from ...sympyextensions import Assignment, AssignmentCollection
from ...sympyextensions.typed_sympy import TypedSymbol
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
from ...field import Field, FieldType
from .context import KernelCreationContext
......@@ -26,6 +26,7 @@ from ..ast.expressions import (
PsConstantExpr,
PsArrayInitList,
PsSubscript,
PsCast
)
from ..constants import PsConstant
......@@ -305,3 +306,6 @@ class FreezeExpressions:
args = tuple(self.visit_expr(arg) for arg in func.args)
return PsCall(func_symbol, args)
def map_CastFunc(self, cast_expr: CastFunc):
return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr))
......@@ -22,6 +22,7 @@ from ..ast.expressions import (
PsLookup,
PsCall,
PsArrayInitList,
PsCast
)
from ..functions import PsMathFunction
......@@ -280,5 +281,9 @@ class Typifier:
arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_and_check(expr, arr_type)
case PsCast(dtype, operand):
self.visit_expr(operand, TypeContext())
tc.apply_and_check(expr, dtype)
case _:
raise NotImplementedError(f"Can't typify {expr}")
......@@ -8,8 +8,8 @@ from pystencils.sympyextensions import Assignment
from pystencils.boundaries.createindexlist import (
create_boundary_index_array, numpy_data_type_for_boundary_object)
from pystencils.sympyextensions import TypedSymbol
from pystencils.defaults import DEFAULTS
from pystencils.types.quick import Arr, create_type
from pystencils.types import PsIntegerType
from pystencils.types.quick import Arr, SInt
from pystencils.gpu.gpu_array_handler import GPUArrayHandler
from pystencils.field import Field, FieldType
from pystencils.backend.kernelfunction import FieldPointerParam
......@@ -417,37 +417,46 @@ class BoundaryOffsetInfo:
@staticmethod
def inv_dir(dir_idx):
return sp.IndexedBase(BoundaryOffsetInfo.INV_DIR_SYMBOL, shape=(1,))[dir_idx]
return sp.IndexedBase(BoundaryOffsetInfo._inv_dir_symbol(), shape=(1,))[dir_idx]
# ---------------------------------- Internal ---------------------------------------------
@staticmethod
def get_array_declarations(stencil) -> list[Assignment]:
dim = len(stencil[0])
def __init__(self, stencil, index_dtype: PsIntegerType = SInt(32)) -> None:
self._stencil = stencil
self._dim = len(stencil[0])
self._index_dtype = index_dtype
def get_array_declarations(self) -> list[Assignment]:
asms = []
for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(dim)):
offsets = tuple(d[i] for d in stencil)
for i, offset_symb in enumerate(BoundaryOffsetInfo._offset_symbols(self._dim)):
offsets = tuple(d[i] for d in self._stencil)
asms.append(Assignment(offset_symb, offsets))
inv_dirs = []
for direction in stencil:
for direction in self._stencil:
inverse_dir = tuple([-i for i in direction])
inv_dirs.append(str(stencil.index(inverse_dir)))
inv_dirs.append(str(self._stencil.index(inverse_dir)))
asms.append(Assignment(BoundaryOffsetInfo.INV_DIR_SYMBOL, tuple(inv_dirs)))
asms.append(Assignment(BoundaryOffsetInfo._inv_dir_symbol(), tuple(inv_dirs)))
return asms
@staticmethod
def _offset_symbols(dim):
return [TypedSymbol(f"c{d}", Arr(create_type(DEFAULTS.index_dtype))) for d in ['x', 'y', 'z'][:dim]]
def _offset_symbols(dim, dtype: PsIntegerType = SInt(32)):
return [TypedSymbol(f"c{d}", Arr(dtype)) for d in ['x', 'y', 'z'][:dim]]
INV_DIR_SYMBOL = TypedSymbol("invdir", Arr(create_type(DEFAULTS.index_dtype)))
@staticmethod
def _inv_dir_symbol(dtype: PsIntegerType = SInt(32)):
return TypedSymbol("invdir", Arr(dtype))
def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
elements = BoundaryOffsetInfo.get_array_declarations(stencil)
dir_symbol = TypedSymbol("dir", DEFAULTS.index_dtype)
# TODO: reconsider how to control the index_dtype in boundary kernels
config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args)
offset_info = BoundaryOffsetInfo(stencil, config.index_dtype)
elements = offset_info.get_array_declarations()
dir_symbol = TypedSymbol("dir", config.index_dtype)
elements += [Assignment(dir_symbol, index_field[0]('dir'))]
elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
config = CreateKernelConfig(index_field=index_field, target=target, **kernel_creation_args)
return create_kernel(elements, config=config)
import warnings
import numpy as np
from pystencils.defaults import DEFAULTS
from pystencils.types.quick import SInt
try:
......@@ -22,7 +22,7 @@ if cython_funcs_available:
boundary_index_array_coordinate_names = ["x", "y", "z"]
direction_member_name = "dir"
default_index_array_dtype = DEFAULTS.index_dtype
default_index_array_dtype = SInt(32)
def numpy_data_type_for_boundary_object(boundary_object, dim):
......
......@@ -34,6 +34,22 @@ def is_loop_counter_symbol(symbol):
return None
class PsTypeAtom(sp.Atom):
"""Wrapper around a PsType to disguise it as a SymPy atom."""
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
def __init__(self, dtype: PsType) -> None:
self._dtype = dtype
def _sympystr(self, *args, **kwargs):
return str(self._dtype)
def get(self) -> PsType:
return self._dtype
class TypedSymbol(sp.Symbol):
def __new__(cls, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
......@@ -192,7 +208,9 @@ class CastFunc(sp.Function):
# This optimisation is only available for simple casts. Thus the == is intended here!
if expr.__class__ == CastFunc:
expr = expr.args[0]
dtype = create_type(dtype)
if not isinstance(dtype, PsTypeAtom):
dtype = PsTypeAtom(create_type(dtype))
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
# to problems when for example comparing cast_func's for equality
......@@ -220,8 +238,9 @@ class CastFunc(sp.Function):
return self.args[0].is_commutative
@property
def dtype(self):
return self.args[1]
def dtype(self) -> PsType:
assert isinstance(self.args[1], PsTypeAtom)
return self.args[1].get()
@property
def expr(self):
......
......@@ -244,6 +244,3 @@ def test_dirichlet(with_indices):
assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[1:-2, -1]])
assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[0, 1:-2]])
assert all([np.allclose(a, np.array(value)) for a in dh.cpu_arrays.src[-1, 1:-2]])
test_kernel_vs_copy_boundary()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment