diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 3d3b7846a84bb9c477b38e90839d7f67fe12933c..c39cd3b826c733da9dceb1944f0cce5038c348e0 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -6,7 +6,7 @@ from . import fd from . import stencil as stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields -from .types import create_type +from .types import create_type, create_numeric_type from .cache import clear_cache from .config import ( CreateKernelConfig, @@ -41,6 +41,7 @@ __all__ = [ "DEFAULTS", "TypedSymbol", "create_type", + "create_numeric_type", "make_slice", "CreateKernelConfig", "CpuOptimConfig", diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 3865db38fe603a6cf5fe4d31deef1743d4276bd6..59fa04b3bdb6af98ac584ef7d166167c64423862 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -7,13 +7,12 @@ import sympy.core.relational import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment +from ...sympyextensions.astnodes import Assignment, AssignmentCollection from ...sympyextensions import ( - Assignment, - AssignmentCollection, integer_functions, ConditionalFieldAccess, ) -from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc +from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType from ...sympyextensions.pointers import AddressOf from ...field import Field, FieldType @@ -58,7 +57,7 @@ from ..ast.expressions import ( ) from ..constants import PsConstant -from ...types import PsStructType +from ...types import PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -465,7 +464,16 @@ class FreezeExpressions: return cast(PsCall, args[0]) def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: - return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) + dtype: PsType + match cast_expr.dtype: + case DynamicType.NUMERIC_TYPE: + dtype = self._ctx.default_dtype + case DynamicType.INDEX_TYPE: + dtype = self._ctx.index_dtype + case other if isinstance(other, PsType): + dtype = other + + return PsCast(dtype, self.visit_expr(cast_expr.expr)) def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index e022db511ed9e637d0a4c2eea31d62a2214dd9ca..c81a189eee29b88d84906ff6c8112f388cd67476 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import sympy as sp +from enum import Enum, auto -from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type +from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, PsIntegerType, create_type def assumptions_from_dtype(dtype: PsType): @@ -33,20 +36,28 @@ def is_loop_counter_symbol(symbol): return None -class PsTypeAtom(sp.Atom): - """Wrapper around a PsType to disguise it as a SymPy atom.""" +class DynamicType(Enum): + NUMERIC_TYPE = auto() + INDEX_TYPE = auto() + + +class TypeAtom(sp.Atom): + """Wrapper around a type to disguise it as a SymPy atom.""" def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) - def __init__(self, dtype: PsType) -> None: + def __init__(self, dtype: PsType | DynamicType) -> None: self._dtype = dtype def _sympystr(self, *args, **kwargs): return str(self._dtype) - def get(self) -> PsType: + def get(self) -> PsType | DynamicType: return self._dtype + + def _hashable_content(self): + return (self._dtype, ) class TypedSymbol(sp.Symbol): @@ -63,7 +74,7 @@ class TypedSymbol(sp.Symbol): assumptions.update(kwargs) obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) - obj._dtype = create_type(dtype) + obj._dtype = dtype return obj @@ -105,12 +116,15 @@ class FieldStrideSymbol(TypedSymbol): obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name: str, coordinate: int): + def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None): from ..defaults import DEFAULTS + + if dtype is None: + dtype = DEFAULTS.index_dtype name = f"_stride_{field_name}_{coordinate}" obj = super(FieldStrideSymbol, cls).__xnew__( - cls, name, DEFAULTS.index_dtype, positive=True + cls, name, dtype, positive=True ) obj.field_name = field_name obj.coordinate = coordinate @@ -138,12 +152,15 @@ class FieldShapeSymbol(TypedSymbol): obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name: str, coordinate: int): + def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None): from ..defaults import DEFAULTS + + if dtype is None: + dtype = DEFAULTS.index_dtype name = f"_size_{field_name}_{coordinate}" obj = super(FieldShapeSymbol, cls).__xnew__( - cls, name, DEFAULTS.index_dtype, positive=True + cls, name, dtype, positive=True ) obj.field_name = field_name obj.coordinate = coordinate @@ -190,10 +207,21 @@ class FieldPointerSymbol(TypedSymbol): class CastFunc(sp.Function): + """Use this function to introduce a static type cast into the output code. + + Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``. + The `target_type` may be a valid pystencils type specification parsable by `create_type`, + or a special value of the `DynamicType` enum. + These dynamic types can be used to select the target type according to the code generation context. """ - CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type - a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number. - """ + + @staticmethod + def as_numeric(expr): + return CastFunc(expr, DynamicType.NUMERIC_TYPE) + + @staticmethod + def as_index(expr): + return CastFunc(expr, DynamicType.INDEX_TYPE) is_Atom = True @@ -207,8 +235,12 @@ class CastFunc(sp.Function): if expr.__class__ == CastFunc: expr = expr.args[0] - if not isinstance(dtype, PsTypeAtom): - dtype = PsTypeAtom(create_type(dtype)) + if not isinstance(dtype, (TypeAtom)): + if isinstance(dtype, DynamicType): + dtype = TypeAtom(dtype) + else: + dtype = TypeAtom(create_type(dtype)) + # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads # to problems when for example comparing cast_func's for equality @@ -236,8 +268,8 @@ class CastFunc(sp.Function): return self.args[0].is_commutative @property - def dtype(self) -> PsType: - assert isinstance(self.args[1], PsTypeAtom) + def dtype(self) -> PsType | DynamicType: + assert isinstance(self.args[1], TypeAtom) return self.args[1].get() @property @@ -246,7 +278,9 @@ class CastFunc(sp.Function): @property def is_integer(self): - if isinstance(self.dtype, PsNumericType): + if self.dtype == DynamicType.INDEX_TYPE: + return True + elif isinstance(self.dtype, PsNumericType): return self.dtype.is_int() or super().is_integer else: return super().is_integer diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 75fb35d223a9dbc1cfd1090a661e0c61a5335cf8..d6522e5bbf6c37c5f105114b40ad73b82fe3e1fd 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -158,6 +158,8 @@ def parse_type_name(typename: str, const: bool): case "uint8" | "uint8_t": return PsUnsignedIntegerType(8, const=const) + case "half" | "float16": + return PsIeeeFloatType(16, const=const) case "float" | "float32": return PsIeeeFloatType(32, const=const) case "double" | "float64": diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 2f0f2ff46ec1498eda724f23964bf28aa2ad2a9b..61e3d73fd5b66e06b48ea9de788b3fa1b51a7a61 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -72,7 +72,7 @@ class PsPointerType(PsDereferencableType): __match_args__ = ("base_type",) - def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False): + def __init__(self, base_type: PsType, restrict: bool = False, const: bool = False): super().__init__(base_type, const) self._restrict = restrict @@ -94,7 +94,7 @@ class PsPointerType(PsDereferencableType): return f"{base_str} *{restrict_str} {self._const_string()}" def __repr__(self) -> str: - return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" + return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )" class PsArrayType(PsDereferencableType): @@ -200,7 +200,7 @@ class PsStructType(PsType): @property def numpy_dtype(self) -> np.dtype: members = [(m.name, m.dtype.numpy_dtype) for m in self._members] - return np.dtype(members) + return np.dtype(members, align=True) @property def itemsize(self) -> int: diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index b22df7d0bd132cc530e289b630f9c48851e4996b..f16a468e715bd64a4c176794a73134f6e6807e50 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,7 +1,8 @@ import sympy as sp import pytest -from pystencils import Assignment, fields +from pystencils import Assignment, fields, create_type, create_numeric_type +from pystencils.sympyextensions import CastFunc from pystencils.backend.ast.structural import ( PsAssignment, @@ -26,7 +27,8 @@ from pystencils.backend.ast.expressions import ( PsLe, PsGt, PsGe, - PsCall + PsCall, + PsCast, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -182,14 +184,17 @@ def test_freeze_booleans(): assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2)) -@pytest.mark.parametrize("rel_pair", [ - (sp.Eq, PsEq), - (sp.Ne, PsNe), - (sp.Lt, PsLt), - (sp.Gt, PsGt), - (sp.Le, PsLe), - (sp.Ge, PsGe) -]) +@pytest.mark.parametrize( + "rel_pair", + [ + (sp.Eq, PsEq), + (sp.Ne, PsNe), + (sp.Lt, PsLt), + (sp.Gt, PsGt), + (sp.Le, PsLe), + (sp.Ge, PsGe), + ], +) def test_freeze_relations(rel_pair): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) @@ -211,7 +216,7 @@ def test_freeze_piecewise(): freeze = FreezeExpressions(ctx) p, q, x, y, z = sp.symbols("p, q, x, y, z") - + p2 = PsExpression.make(ctx.get_symbol("p")) q2 = PsExpression.make(ctx.get_symbol("q")) x2 = PsExpression.make(ctx.get_symbol("x")) @@ -222,10 +227,10 @@ def test_freeze_piecewise(): expr = freeze(piecewise) assert isinstance(expr, PsTernary) - + should = PsTernary(p2, x2, PsTernary(q2, y2, z2)) assert expr.structurally_equal(should) - + piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) with pytest.raises(FreezeError): freeze(piecewise) @@ -259,3 +264,25 @@ def test_multiarg_min_max(): expr = freeze(sp.Max(w, x, y, z)) assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) + + +def test_cast_func(): + ctx = KernelCreationContext( + default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16") + ) + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + expr = freeze(CastFunc(x, create_type("int"))) + assert expr.structurally_equal(PsCast(create_type("int"), x2)) + + expr = freeze(CastFunc.as_numeric(y)) + assert expr.structurally_equal(PsCast(ctx.default_dtype, y2)) + + expr = freeze(CastFunc.as_index(z)) + assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 39f89e6fe6ef7ff77fdf5534eaebb3510f9caf4b..1cc2ae0e4a213df51ccf80578a6ba028771e4f0c 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -139,6 +139,17 @@ def test_struct_types(): with pytest.raises(PsTypeError): t.c_string() + t = PsStructType([ + ("a", SInt(8)), + ("b", SInt(16)), + ("c", SInt(64)) + ]) + + # Check that natural alignment is taken into account + numpy_type = np.dtype([("a", "i1"), ("b", "i2"), ("c", "i8")], align=True) + assert t.numpy_dtype == numpy_type + assert t.itemsize == numpy_type.itemsize == 16 + def test_pickle(): types = [ diff --git a/tests/symbolics/test_typed_sympy.py b/tests/symbolics/test_typed_sympy.py new file mode 100644 index 0000000000000000000000000000000000000000..41015f96bfa6950a57f9ccfa3194c128c2bc0f69 --- /dev/null +++ b/tests/symbolics/test_typed_sympy.py @@ -0,0 +1,57 @@ +import numpy as np + +from pystencils.sympyextensions.typed_sympy import ( + TypedSymbol, + CastFunc, + TypeAtom, + DynamicType, +) +from pystencils.types import create_type +from pystencils.types.quick import UInt, Ptr + + +def test_type_atoms(): + atom1 = TypeAtom(create_type("int32")) + atom2 = TypeAtom(create_type("int32")) + + assert atom1 == atom2 + + atom3 = TypeAtom(create_type("const int32")) + assert atom1 != atom3 + + atom4 = TypeAtom(DynamicType.INDEX_TYPE) + atom5 = TypeAtom(DynamicType.NUMERIC_TYPE) + + assert atom3 != atom4 + assert atom4 != atom5 + + +def test_typed_symbol(): + x = TypedSymbol("x", "uint32") + x2 = TypedSymbol("x", "uint64 *") + z = TypedSymbol("z", "float32") + + assert x == TypedSymbol("x", np.uint32) + assert x != x2 + + assert x.dtype == UInt(32) + assert x2.dtype == Ptr(UInt(64)) + + assert x.is_integer + assert x.is_nonnegative + + assert not x2.is_integer + + assert z.is_real + assert not z.is_nonnegative + + +def test_cast_func(): + assert ( + CastFunc(TypedSymbol("s", np.uint), np.int64).canonical + == TypedSymbol("s", np.uint).canonical + ) + + a = CastFunc(5, np.uint) + assert a.is_negative is False + assert a.is_nonnegative