From 5ee0ec34770f83c622be08e0cb9709f9f4e9e22b Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 24 Jul 2024 17:29:26 +0200 Subject: [PATCH] Refactor Field Indexing Symbols --- src/pystencils/__init__.py | 3 +- src/pystencils/backend/arrays.py | 74 ++++---- .../backend/kernelcreation/context.py | 73 ++++++-- .../backend/kernelcreation/freeze.py | 7 + src/pystencils/backend/kernelfunction.py | 17 -- .../datahandling/parallel_datahandling.py | 6 +- src/pystencils/defaults.py | 13 +- src/pystencils/field.py | 22 ++- src/pystencils/sympyextensions/math.py | 4 +- src/pystencils/sympyextensions/typed_sympy.py | 175 +++++------------- src/pystencils/typing.py | 11 ++ tests/nbackend/kernelcreation/test_context.py | 101 ++++++++++ tests/nbackend/kernelcreation/test_freeze.py | 21 ++- 13 files changed, 317 insertions(+), 210 deletions(-) create mode 100644 src/pystencils/typing.py create mode 100644 tests/nbackend/kernelcreation/test_context.py diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index f5cb3e10b..a4685d3a7 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -32,7 +32,7 @@ from .spatial_coordinates import ( ) from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from .simp import AssignmentCollection -from .sympyextensions.typed_sympy import TypedSymbol +from .sympyextensions.typed_sympy import TypedSymbol, DynamicType from .sympyextensions import SymbolCreator from .datahandling import create_data_handling @@ -43,6 +43,7 @@ __all__ = [ "fields", "DEFAULTS", "TypedSymbol", + "DynamicType", "create_type", "create_numeric_type", "make_slice", diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index 45f22cb87..9aefeaf62 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -37,8 +37,8 @@ class PsLinearizedArray: self, name: str, element_type: PsType, - shape: Sequence[int | EllipsisType], - strides: Sequence[int | EllipsisType], + shape: Sequence[int | str | EllipsisType], + strides: Sequence[int | str | EllipsisType], index_dtype: PsIntegerType = DEFAULTS.index_dtype, ): self._name = name @@ -48,25 +48,33 @@ class PsLinearizedArray: if len(shape) != len(strides): raise ValueError("Shape and stride tuples must have the same length") + def make_shape(coord, name_or_val): + match name_or_val: + case EllipsisType(): + return PsArrayShapeSymbol(DEFAULTS.field_shape_name(name, coord), self, coord) + case str(): + return PsArrayShapeSymbol(name_or_val, self, coord) + case _: + return PsConstant(name_or_val, index_dtype) + self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple( - ( - PsArrayShapeSymbol(self, i, index_dtype) - if s == Ellipsis - else PsConstant(s, index_dtype) - ) - for i, s in enumerate(shape) + make_shape(i, s) for i, s in enumerate(shape) ) + def make_stride(coord, name_or_val): + match name_or_val: + case EllipsisType(): + return PsArrayStrideSymbol(DEFAULTS.field_stride_name(name, coord), self, coord) + case str(): + return PsArrayStrideSymbol(name_or_val, self, coord) + case _: + return PsConstant(name_or_val, index_dtype) + self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple( - ( - PsArrayStrideSymbol(self, i, index_dtype) - if s == Ellipsis - else PsConstant(s, index_dtype) - ) - for i, s in enumerate(strides) + make_stride(i, s) for i, s in enumerate(strides) ) - self._base_ptr = PsArrayBasePointer(f"{self._name}_data", self) + self._base_ptr = PsArrayBasePointer(DEFAULTS.field_pointer_name(name), self) @property def name(self): @@ -83,27 +91,17 @@ class PsLinearizedArray: """The array's shape, expressed using `PsConstant` and `PsArrayShapeSymbol`""" return self._shape - @property - def shape_spec(self) -> tuple[EllipsisType | int, ...]: - """The array's shape, expressed using `int` and `...`""" - return tuple( - (s.value if isinstance(s, PsConstant) else ...) for s in self._shape - ) - @property def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]: """The array's strides, expressed using `PsConstant` and `PsArrayStrideSymbol`""" return self._strides @property - def strides_spec(self) -> tuple[EllipsisType | int, ...]: - """The array's strides, expressed using `int` and `...`""" - return tuple( - (s.value if isinstance(s, PsConstant) else ...) for s in self._strides - ) + def index_type(self) -> PsIntegerType: + return self._index_dtype @property - def element_type(self): + def element_type(self) -> PsType: return self._element_type def __repr__(self) -> str: @@ -159,9 +157,13 @@ class PsArrayShapeSymbol(PsArrayAssocSymbol): __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): - name = f"_size_{array.name}_{coordinate}" - super().__init__(name, dtype, array) + def __init__( + self, + name: str, + array: PsLinearizedArray, + coordinate: int, + ): + super().__init__(name, array.index_type, array) self._coordinate = coordinate @property @@ -178,9 +180,13 @@ class PsArrayStrideSymbol(PsArrayAssocSymbol): __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): - name = f"_stride_{array.name}_{coordinate}" - super().__init__(name, dtype, array) + def __init__( + self, + name: str, + array: PsLinearizedArray, + coordinate: int, + ): + super().__init__(name, array.index_type, array) self._coordinate = coordinate @property diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 464c08dd9..73e3c70cc 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -2,17 +2,23 @@ from __future__ import annotations from typing import Iterable, Iterator, Any from itertools import chain, count -from types import EllipsisType from collections import namedtuple, defaultdict import re from ...defaults import DEFAULTS from ...field import Field, FieldType -from ...sympyextensions.typed_sympy import TypedSymbol +from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType from ..symbols import PsSymbol from ..arrays import PsLinearizedArray -from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify +from ...types import ( + PsType, + PsIntegerType, + PsNumericType, + PsScalarType, + PsStructType, + deconstify, +) from ..constraints import KernelParamsConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError @@ -97,7 +103,7 @@ class KernelCreationContext: @property def constraints(self) -> tuple[KernelParamsConstraint, ...]: return tuple(self._constraints) - + @property def metadata(self) -> dict[str, Any]: return self._metadata @@ -215,8 +221,27 @@ class KernelCreationContext: else: return - arr_shape: list[EllipsisType | int] | None = None - arr_strides: list[EllipsisType | int] | None = None + arr_shape: list[str | int] | None = None + arr_strides: list[str | int] | None = None + + def normalize_type(s: TypedSymbol) -> PsIntegerType: + match s.dtype: + case DynamicType.INDEX_TYPE: + return self.index_dtype + case DynamicType.NUMERIC_TYPE: + if isinstance(self.default_dtype, PsIntegerType): + return self.default_dtype + else: + raise KernelConstraintsError( + f"Cannot use non-integer default numeric type {self.default_dtype} " + f"in field indexing symbol {s}." + ) + case PsIntegerType(): + return deconstify(s.dtype) + case _: + raise KernelConstraintsError( + f"Invalid data type for field indexing symbol {s}: {s.dtype}" + ) # Check field constraints and add to collection match field.field_type: @@ -243,7 +268,15 @@ class KernelCreationContext: "Buffer fields cannot have variable index shape." ) - arr_shape = [..., num_entries] + buffer_len = field.spatial_shape[0] + + if isinstance(buffer_len, TypedSymbol): + idx_type = normalize_type(buffer_len) + arr_shape = [buffer_len.name, num_entries] + else: + idx_type = DEFAULTS.index_dtype + arr_shape = [buffer_len, num_entries] + arr_strides = [num_entries, 1] self._fields_collection.buffer_fields.add(field) @@ -265,18 +298,25 @@ class KernelCreationContext: # For non-buffer fields, determine shape and strides if arr_shape is None: + idx_types = set( + normalize_type(s) + for s in chain(field.shape, field.strides) + if isinstance(s, TypedSymbol) + ) + + if len(idx_types) > 1: + raise KernelConstraintsError( + f"Multiple incompatible types found in index symbols of field {field}: " + f"{idx_types}" + ) + idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype + arr_shape = [ - ( - Ellipsis if isinstance(s, TypedSymbol) else s - ) # TODO: Field should also use ellipsis - for s in field.shape + (s.name if isinstance(s, TypedSymbol) else s) for s in field.shape ] arr_strides = [ - ( - Ellipsis if isinstance(s, TypedSymbol) else s - ) # TODO: Field should also use ellipsis - for s in field.strides + (s.name if isinstance(s, TypedSymbol) else s) for s in field.strides ] # The frontend doesn't quite agree with itself on how to model @@ -288,12 +328,13 @@ class KernelCreationContext: # Add array assert arr_strides is not None + assert idx_type is not None assert isinstance(field.dtype, (PsScalarType, PsStructType)) element_type = field.dtype arr = PsLinearizedArray( - field.name, element_type, arr_shape, arr_strides, self.index_dtype + field.name, element_type, arr_shape, arr_strides, idx_type ) self._fields_and_arrays[field.name] = FieldArrayPair(field, arr) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 25ce28115..b5c04f1bd 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -272,6 +272,13 @@ class FreezeExpressions: def map_TypedSymbol(self, expr: TypedSymbol): dtype = expr.dtype + + match dtype: + case DynamicType.NUMERIC_TYPE: + dtype = self._ctx.default_dtype + case DynamicType.INDEX_TYPE: + dtype = self._ctx.index_dtype + symb = self._ctx.get_symbol(expr.name, dtype) return PsSymbolExpr(symb) diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 32510731c..a3213350e 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -19,11 +19,6 @@ from ..types import PsType from ..enums import Target from ..field import Field from ..sympyextensions import TypedSymbol -from ..sympyextensions.typed_sympy import ( - FieldShapeSymbol, - FieldStrideSymbol, - FieldPointerSymbol, -) if TYPE_CHECKING: from .jit import JitBase @@ -151,10 +146,6 @@ class FieldShapeParam(FieldParameter): def coordinate(self): return self._coordinate - @property - def symbol(self) -> FieldShapeSymbol: - return FieldShapeSymbol(self.field.name, self.coordinate, self.dtype) - def _hashable_contents(self): return super()._hashable_contents() + (self._coordinate,) @@ -170,10 +161,6 @@ class FieldStrideParam(FieldParameter): def coordinate(self): return self._coordinate - @property - def symbol(self) -> FieldStrideSymbol: - return FieldStrideSymbol(self.field.name, self.coordinate, self.dtype) - def _hashable_contents(self): return super()._hashable_contents() + (self._coordinate,) @@ -182,10 +169,6 @@ class FieldPointerParam(FieldParameter): def __init__(self, name: str, dtype: PsType, field: Field): super().__init__(name, dtype, field) - @property - def symbol(self) -> FieldPointerSymbol: - return FieldPointerSymbol(self.field.name, self.field.dtype, const=True) - class KernelFunction: """A pystencils kernel function. diff --git a/src/pystencils/datahandling/parallel_datahandling.py b/src/pystencils/datahandling/parallel_datahandling.py index adc6439a2..88eea0315 100644 --- a/src/pystencils/datahandling/parallel_datahandling.py +++ b/src/pystencils/datahandling/parallel_datahandling.py @@ -8,8 +8,8 @@ import waLBerla as wlb from pystencils.datahandling.blockiteration import block_iteration, sliced_block_iteration from pystencils.datahandling.datahandling_interface import DataHandling from pystencils.field import Field, FieldType -from pystencils.sympyextensions.typed_sympy import FieldPointerSymbol from pystencils.utils import DotDict +from pystencils.backend.kernelfunction import FieldPointerParam from pystencils import Target @@ -258,9 +258,9 @@ class ParallelDataHandling(DataHandling): else: name_map = self._field_name_to_cpu_data_name to_array = wlb.field.toArray - data_used_in_kernel = [(name_map[p.symbol.field_name], self.fields[p.symbol.field_name]) + data_used_in_kernel = [(name_map[p.field_name], self.fields[p.field_name]) for p in kernel_function.parameters if - isinstance(p.symbol, FieldPointerSymbol) and p.symbol.field_name not in kwargs] + isinstance(p, FieldPointerParam) and p.field_name not in kwargs] result = [] for block in self.blocks: diff --git a/src/pystencils/defaults.py b/src/pystencils/defaults.py index f8e96a3a3..c7ac33347 100644 --- a/src/pystencils/defaults.py +++ b/src/pystencils/defaults.py @@ -1,5 +1,5 @@ from typing import TypeVar, Generic, Callable -from .types import PsType, PsIeeeFloatType, PsSignedIntegerType, PsStructType +from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType from pystencils.sympyextensions.typed_sympy import TypedSymbol @@ -11,7 +11,7 @@ class GenericDefaults(Generic[SymbolT]): self.numeric_dtype = PsIeeeFloatType(64) """Default data type for numerical computations""" - self.index_dtype = PsSignedIntegerType(64) + self.index_dtype: PsIntegerType = PsSignedIntegerType(64) """Default data type for indices.""" self.spatial_counter_names = ("ctr_0", "ctr_1", "ctr_2") @@ -40,6 +40,15 @@ class GenericDefaults(Generic[SymbolT]): self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype) """Default sparse iteration counter.""" + def field_shape_name(self, field_name: str, coord: int): + return f"_size_{field_name}_{coord}" + + def field_stride_name(self, field_name: str, coord: int): + return f"_stride_{field_name}_{coord}" + + def field_pointer_name(self, field_name: str): + return f"_data_{field_name}" + DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol) """Default names and symbols used throughout code generation""" diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 2a7b6d315..51f01deb1 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -11,11 +11,12 @@ import numpy as np import sympy as sp from sympy.core.cache import cacheit +from .defaults import DEFAULTS from pystencils.alignedarray import aligned_empty from pystencils.spatial_coordinates import x_staggered_vector, x_vector from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string from pystencils.types import PsType, PsStructType, create_type -from pystencils.sympyextensions.typed_sympy import FieldShapeSymbol, FieldStrideSymbol, TypedSymbol +from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType from pystencils.sympyextensions import is_integer_sequence from pystencils.types import UserTypeSpec @@ -151,11 +152,22 @@ class Field: total_dimensions = spatial_dimensions + index_dimensions if index_shape is None or len(index_shape) == 0: - shape = tuple([FieldShapeSymbol(field_name, i) for i in range(total_dimensions)]) + shape = tuple([ + TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) + for i in range(total_dimensions) + ]) else: - shape = tuple([FieldShapeSymbol(field_name, i) for i in range(spatial_dimensions)] + list(index_shape)) - - strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)]) + shape = tuple( + [ + TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) + for i in range(spatial_dimensions) + ] + list(index_shape) + ) + + strides = tuple([ + TypedSymbol(DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE) + for i in range(total_dimensions) + ]) dtype = create_type(dtype) np_data_type = dtype.numpy_dtype diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index 1a006efe6..6f99be7ef 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -11,7 +11,7 @@ from sympy.functions import Abs from sympy.core.numbers import Zero from ..assignment import Assignment -from .typed_sympy import CastFunc, FieldPointerSymbol +from .typed_sympy import CastFunc from ..types import PsPointerType, PsVectorType T = TypeVar('T') @@ -565,8 +565,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], def check_type(e): if only_type is None: return True - if isinstance(e, FieldPointerSymbol) and only_type == "real": - return only_type == "int" try: # base_type = get_type_of_expression(e) diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index c81a189ee..5e2eaab6c 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -3,28 +3,13 @@ from __future__ import annotations import sympy as sp from enum import Enum, auto -from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, PsIntegerType, create_type - - -def assumptions_from_dtype(dtype: PsType): - """Derives SymPy assumptions from :class:`PsAbstractType` - - Args: - dtype (PsAbstractType): a pystencils data type - Returns: - A dict of SymPy assumptions - """ - assumptions = dict() - - if isinstance(dtype, PsNumericType): - if dtype.is_int(): - assumptions.update({"integer": True}) - if dtype.is_uint(): - assumptions.update({"negative": False}) - if dtype.is_int() or dtype.is_float(): - assumptions.update({"real": True}) - - return assumptions +from ..types import ( + PsType, + PsNumericType, + PsBoolType, + create_type, + UserTypeSpec +) def is_loop_counter_symbol(symbol): @@ -46,7 +31,7 @@ class TypeAtom(sp.Atom): def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) - + def __init__(self, dtype: PsType | DynamicType) -> None: self._dtype = dtype @@ -55,21 +40,52 @@ class TypeAtom(sp.Atom): def get(self) -> PsType | DynamicType: return self._dtype - + def _hashable_content(self): - return (self._dtype, ) + return (self._dtype,) + + +def assumptions_from_dtype(dtype: PsType | DynamicType): + """Derives SymPy assumptions from :class:`PsAbstractType` + + Args: + dtype (PsAbstractType): a pystencils data type + Returns: + A dict of SymPy assumptions + """ + assumptions = dict() + + match dtype: + case DynamicType.INDEX_TYPE: + assumptions.update({"integer": True, "real": True}) + case DynamicType.NUMERIC_TYPE: + assumptions.update({"real": True}) + case PsNumericType(): + if dtype.is_int(): + assumptions.update({"integer": True}) + if dtype.is_uint(): + assumptions.update({"negative": False}) + if dtype.is_int() or dtype.is_float(): + assumptions.update({"real": True}) + + return assumptions class TypedSymbol(sp.Symbol): + + _dtype: PsType | DynamicType + def __new__(cls, *args, **kwds): obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds) return obj def __new_stage2__( - cls, name, dtype, **kwargs + cls, name: str, dtype: UserTypeSpec | DynamicType, **kwargs ): # TODO does not match signature of sp.Symbol??? # TODO: also Symbol should be allowed ---> see sympy Variable - dtype = create_type(dtype) + if not isinstance(dtype, DynamicType): + dtype = create_type(dtype) + assumptions = assumptions_from_dtype(dtype) assumptions.update(kwargs) @@ -82,7 +98,7 @@ class TypedSymbol(sp.Symbol): __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) @property - def dtype(self) -> PsType: + def dtype(self) -> PsType | DynamicType: # mypy: ignore return self._dtype @@ -106,104 +122,7 @@ class TypedSymbol(sp.Symbol): @property def headers(self) -> set[str]: - return self.dtype.required_headers - - -class FieldStrideSymbol(TypedSymbol): - """Sympy symbol representing the stride value of a field in a specific coordinate.""" - - def __new__(cls, *args, **kwds): - obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds) - return obj - - def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None): - from ..defaults import DEFAULTS - - if dtype is None: - dtype = DEFAULTS.index_dtype - - name = f"_stride_{field_name}_{coordinate}" - obj = super(FieldStrideSymbol, cls).__xnew__( - cls, name, dtype, positive=True - ) - obj.field_name = field_name - obj.coordinate = coordinate - return obj - - def __getnewargs__(self): - return self.field_name, self.coordinate - - def __getnewargs_ex__(self): - return (self.field_name, self.coordinate), {} - - __xnew__ = staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) - - def _hashable_content(self): - return super()._hashable_content(), self.coordinate, self.field_name - - -class FieldShapeSymbol(TypedSymbol): - """Sympy symbol representing the shape value of a sequence of fields. In a kernel iterating over multiple fields - there is only one set of `FieldShapeSymbol`s since all the fields have to be of equal size. - """ - - def __new__(cls, *args, **kwds): - obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds) - return obj - - def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None): - from ..defaults import DEFAULTS - - if dtype is None: - dtype = DEFAULTS.index_dtype - - name = f"_size_{field_name}_{coordinate}" - obj = super(FieldShapeSymbol, cls).__xnew__( - cls, name, dtype, positive=True - ) - obj.field_name = field_name - obj.coordinate = coordinate - return obj - - def __getnewargs__(self): - return self.field_name, self.coordinate - - def __getnewargs_ex__(self): - return (self.field_name, self.coordinate), {} - - __xnew__ = staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) - - def _hashable_content(self): - return super()._hashable_content(), self.coordinate, self.field_name - - -class FieldPointerSymbol(TypedSymbol): - """Sympy symbol representing the pointer to the beginning of the field data.""" - - def __new__(cls, *args, **kwds): - obj = FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds) - return obj - - def __new_stage2__(cls, field_name, field_dtype: PsType, const: bool): - name = f"_data_{field_name}" - dtype = PsPointerType(field_dtype, restrict=True, const=const) - obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) - obj.field_name = field_name - return obj - - def __getnewargs__(self): - return self.field_name, self.dtype, self.dtype.const - - def __getnewargs_ex__(self): - return (self.field_name, self.dtype, self.dtype.const), {} - - def _hashable_content(self): - return super()._hashable_content(), self.field_name - - __xnew__ = staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) + return self.dtype.required_headers if isinstance(self.dtype, PsType) else set() class CastFunc(sp.Function): @@ -218,7 +137,7 @@ class CastFunc(sp.Function): @staticmethod def as_numeric(expr): return CastFunc(expr, DynamicType.NUMERIC_TYPE) - + @staticmethod def as_index(expr): return CastFunc(expr, DynamicType.INDEX_TYPE) @@ -240,7 +159,7 @@ class CastFunc(sp.Function): dtype = TypeAtom(dtype) else: dtype = TypeAtom(create_type(dtype)) - + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality diff --git a/src/pystencils/typing.py b/src/pystencils/typing.py new file mode 100644 index 000000000..22f8f673e --- /dev/null +++ b/src/pystencils/typing.py @@ -0,0 +1,11 @@ +from .sympyextensions import TypedSymbol as _TypedSymbol +from .types import create_type as _create_type + +from warnings import warn +warn( + "Importing `TypedSymbol` and `create_type` from `pystencils.typing` is deprecated. " + "Import from `pystencils` instead." +) + +TypedSymbol = _TypedSymbol +create_type = _create_type diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py new file mode 100644 index 000000000..9701013b0 --- /dev/null +++ b/tests/nbackend/kernelcreation/test_context.py @@ -0,0 +1,101 @@ +from itertools import chain +import pytest + +from pystencils import Field, TypedSymbol, FieldType, DynamicType + +from pystencils.backend.kernelcreation import KernelCreationContext +from pystencils.backend.constants import PsConstant +from pystencils.backend.exceptions import KernelConstraintsError +from pystencils.types.quick import SInt, Fp +from pystencils.types import deconstify + + +def test_field_arrays(): + ctx = KernelCreationContext(index_dtype=SInt(16)) + + f = Field.create_generic("f", 3, Fp(32)) + f_arr = ctx.get_array(f) + + assert f_arr.element_type == f.dtype == Fp(32) + assert len(f_arr.shape) == len(f.shape) + 1 == 4 + assert isinstance(f_arr.shape[3], PsConstant) and f_arr.shape[3].value == 1 + assert f_arr.shape[3].dtype == SInt(16, const=True) + assert f_arr.index_type == ctx.index_dtype == SInt(16) + assert f_arr.shape[0].dtype == ctx.index_dtype == SInt(16) + + g = Field.create_generic("g", 3, index_shape=(2, 4), dtype=Fp(16)) + g_arr = ctx.get_array(g) + + assert g_arr.element_type == g.dtype == Fp(16) + assert len(g_arr.shape) == len(g.spatial_shape) + len(g.index_shape) == 5 + assert isinstance(g_arr.shape[3], PsConstant) and g_arr.shape[3].value == 2 + assert g_arr.shape[3].dtype == SInt(16, const=True) + assert isinstance(g_arr.shape[4], PsConstant) and g_arr.shape[4].value == 4 + assert g_arr.shape[4].dtype == SInt(16, const=True) + assert g_arr.index_type == ctx.index_dtype == SInt(16) + + h = Field( + "h", + FieldType.GENERIC, + Fp(32), + (0, 1), + ( + TypedSymbol("nx", SInt(32)), + TypedSymbol("ny", SInt(32)), + 1 + ), + ( + TypedSymbol("sx", SInt(32)), + TypedSymbol("sy", SInt(32)), + 1 + ) + ) + + h_arr = ctx.get_array(h) + + assert h_arr.index_type == SInt(32) + + for s in chain(h_arr.shape, h_arr.strides): + assert deconstify(s.get_dtype()) == SInt(32) + + assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == ["nx", "ny", "sx", "sy"] + + +def test_invalid_fields(): + ctx = KernelCreationContext(index_dtype=SInt(16)) + + h = Field( + "h", + FieldType.GENERIC, + Fp(32), + (0,), + (TypedSymbol("nx", SInt(32)),), + (TypedSymbol("sx", SInt(64)),) + ) + + with pytest.raises(KernelConstraintsError): + _ = ctx.get_array(h) + + h = Field( + "h", + FieldType.GENERIC, + Fp(32), + (0,), + (TypedSymbol("nx", Fp(32)),), + (TypedSymbol("sx", Fp(32)),) + ) + + with pytest.raises(KernelConstraintsError): + _ = ctx.get_array(h) + + h = Field( + "h", + FieldType.GENERIC, + Fp(32), + (0,), + (TypedSymbol("nx", DynamicType.NUMERIC_TYPE),), + (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),) + ) + + with pytest.raises(KernelConstraintsError): + _ = ctx.get_array(h) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index f16a468e7..270c8f44a 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,7 +1,7 @@ import sympy as sp import pytest -from pystencils import Assignment, fields, create_type, create_numeric_type +from pystencils import Assignment, fields, create_type, create_numeric_type, TypedSymbol, DynamicType from pystencils.sympyextensions import CastFunc from pystencils.backend.ast.structural import ( @@ -266,6 +266,25 @@ def test_multiarg_min_max(): assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) +def test_dynamic_types(): + ctx = KernelCreationContext( + default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16") + ) + freeze = FreezeExpressions(ctx) + + x, y = [TypedSymbol(n, DynamicType.NUMERIC_TYPE) for n in "xy"] + p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"] + + expr = freeze(x + y) + + assert ctx.get_symbol("x").dtype == ctx.default_dtype + assert ctx.get_symbol("y").dtype == ctx.default_dtype + + expr = freeze(p - q) + assert ctx.get_symbol("p").dtype == ctx.index_dtype + assert ctx.get_symbol("q").dtype == ctx.index_dtype + + def test_cast_func(): ctx = KernelCreationContext( default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16") -- GitLab