From 095436ed00d326c63298ede52ffbc0474b0c7cc8 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 31 Jan 2024 14:09:40 +0100 Subject: [PATCH] additional tests & some fixes --- .../nbackend/kernelcreation/freeze.py | 8 +-- .../nbackend/kernelcreation/typification.py | 5 +- src/pystencils/nbackend/types/__init__.py | 3 +- src/pystencils/nbackend/types/basic_types.py | 9 +-- src/pystencils/nbackend/types/parsing.py | 2 + src/pystencils/nbackend/types/quick.py | 17 +++-- tests/nbackend/test_cpujit.py | 4 +- tests/nbackend/test_types.py | 72 +++++++++++++++++++ tests/nbackend/test_typification.py | 33 ++++++++- 9 files changed, 132 insertions(+), 21 deletions(-) create mode 100644 tests/nbackend/test_types.py diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index 11ea24219..22eb1dbf4 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -32,15 +32,15 @@ class FreezeExpressions(SympyToPymbolicMapper): @overload def __call__(self, asms: AssignmentCollection) -> PsBlock: - ... + pass @overload def __call__(self, expr: sp.Expr) -> PsExpression: - ... + pass @overload - def __call__(self, expr: Assignment) -> PsAssignment: - ... + def __call__(self, asm: Assignment) -> PsAssignment: + pass def __call__(self, obj): if isinstance(obj, AssignmentCollection): diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index c8421fbfe..f1d299e28 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -171,8 +171,9 @@ class Typifier(Mapper): def typify_expression( self, expr: Any, target_type: PsNumericType | None = None - ) -> ExprOrConstant: - return self.rec(expr, TypeContext(target_type)) + ) -> tuple[ExprOrConstant, PsNumericType]: + tc = TypeContext(target_type) + return self.rec(expr, tc) # Leaf nodes: Variables, Typed Variables, Constants and TypedConstants diff --git a/src/pystencils/nbackend/types/__init__.py b/src/pystencils/nbackend/types/__init__.py index 13deab6b4..d7eb490c5 100644 --- a/src/pystencils/nbackend/types/__init__.py +++ b/src/pystencils/nbackend/types/__init__.py @@ -13,7 +13,7 @@ from .basic_types import ( deconstify, ) -from .quick import make_type +from .quick import make_type, make_numeric_type from .exception import PsTypeError @@ -31,5 +31,6 @@ __all__ = [ "constify", "deconstify", "make_type", + "make_numeric_type", "PsTypeError", ] diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index 540f28334..bb27e3493 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -208,10 +208,9 @@ class PsStructType(PsAbstractType): def _c_string(self) -> str: if self._name is None: - # raise PsInternalCompilerError( - # "Cannot retrieve C string for anonymous struct type" - # ) - return "<anonymous>" + raise PsInternalCompilerError( + "Cannot retrieve C string for anonymous struct type" + ) return self._name def __eq__(self, other: object) -> bool: @@ -502,6 +501,8 @@ class PsIeeeFloatType(PsScalarType): def _c_string(self) -> str: match self._width: + case 16: + return f"{self._const_string()}half" case 32: return f"{self._const_string()}float" case 64: diff --git a/src/pystencils/nbackend/types/parsing.py b/src/pystencils/nbackend/types/parsing.py index 952438f11..be9600c71 100644 --- a/src/pystencils/nbackend/types/parsing.py +++ b/src/pystencils/nbackend/types/parsing.py @@ -34,6 +34,8 @@ def interpret_python_type(t: type) -> PsAbstractType: if t is np.int64: return PsSignedIntegerType(64) + if t is np.float16: + return PsIeeeFloatType(16) if t is np.float32: return PsIeeeFloatType(32) if t is np.float64: diff --git a/src/pystencils/nbackend/types/quick.py b/src/pystencils/nbackend/types/quick.py index cf65897d7..e5d271cf9 100644 --- a/src/pystencils/nbackend/types/quick.py +++ b/src/pystencils/nbackend/types/quick.py @@ -11,6 +11,7 @@ import numpy as np from .basic_types import ( PsAbstractType, PsCustomType, + PsNumericType, PsScalarType, PsPointerType, PsIntegerType, @@ -39,11 +40,7 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType: - Instances of `PsAbstractType` will be returned as they are """ - from .parsing import ( - parse_type_string, - interpret_python_type, - interpret_numpy_dtype - ) + from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype if isinstance(type_spec, PsAbstractType): return type_spec @@ -56,6 +53,16 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType: raise ValueError(f"{type_spec} is not a valid type specification.") +def make_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: + """Like `make_type`, but only for numeric types.""" + dtype = make_type(type_spec) + if not isinstance(dtype, PsNumericType): + raise ValueError( + f"Given type {type_spec} does not translate to a numeric type." + ) + return dtype + + Custom = PsCustomType """`Custom(name)` matches `PsCustomType(name)`""" diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index 6c2a453c7..b93f7a1e2 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load def test_pairwise_addition(): idx_type = SInt(64) - u = PsLinearizedArray("u", Fp(64, const=True), (..., ...), (..., ...), index_dtype=idx_type) - v = PsLinearizedArray("v", Fp(64), (..., ...), (..., ...), index_dtype=idx_type) + u = PsLinearizedArray("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) + v = PsLinearizedArray("v", Fp(64), (...,), (...,), index_dtype=idx_type) u_data = PsArrayBasePointer("u_data", u) v_data = PsArrayBasePointer("v_data", v) diff --git a/tests/nbackend/test_types.py b/tests/nbackend/test_types.py new file mode 100644 index 000000000..ba5746222 --- /dev/null +++ b/tests/nbackend/test_types.py @@ -0,0 +1,72 @@ +import pytest +import numpy as np + +from pystencils.nbackend.exceptions import PsInternalCompilerError +from pystencils.nbackend.types import * +from pystencils.nbackend.types.quick import * + + +@pytest.mark.parametrize( + "numpy_type", + [ + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.int8, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + ], +) +def test_numpy_translation(numpy_type): + dtype_obj = np.dtype(numpy_type) + ps_type = make_type(numpy_type) + + assert isinstance(ps_type, PsNumericType) + assert ps_type.numpy_dtype == dtype_obj + assert ps_type.itemsize == dtype_obj.itemsize + + assert isinstance(ps_type.create_constant(13), numpy_type) + + if ps_type.is_int(): + with pytest.raises(PsTypeError): + ps_type.create_constant(13.0) + with pytest.raises(PsTypeError): + ps_type.create_constant(1.75) + + if ps_type.is_sint(): + assert numpy_type(17) == ps_type.create_constant(17) + assert numpy_type(-4) == ps_type.create_constant(-4) + + if ps_type.is_uint(): + with pytest.raises(PsTypeError): + ps_type.create_constant(-4) + + if ps_type.is_float(): + assert numpy_type(17.3) == ps_type.create_constant(17.3) + assert numpy_type(-4.2) == ps_type.create_constant(-4.2) + + +def test_constify(): + t = PsCustomType("std::shared_ptr< Custom >") + assert deconstify(t) == t + assert deconstify(constify(t)) == t + s = PsCustomType("Field", const=True) + assert constify(s) == s + + +def test_struct_types(): + t = PsStructType( + [ + PsStructType.Member("data", Ptr(Fp(32))), + ("size", UInt(32)), + ] + ) + + assert t.anonymous + with pytest.raises(PsInternalCompilerError): + str(t) diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/test_typification.py index 6caadb084..ae477fe19 100644 --- a/tests/nbackend/test_typification.py +++ b/tests/nbackend/test_typification.py @@ -1,16 +1,17 @@ import pytest import sympy as sp +import numpy as np import pymbolic.primitives as pb -from pystencils import Assignment +from pystencils import Assignment, TypedSymbol from pystencils.nbackend.ast import PsDeclaration -from pystencils.nbackend.types import constify +from pystencils.nbackend.types import constify, make_numeric_type from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable from pystencils.nbackend.kernelcreation.options import KernelCreationOptions from pystencils.nbackend.kernelcreation.context import KernelCreationContext from pystencils.nbackend.kernelcreation.freeze import FreezeExpressions -from pystencils.nbackend.kernelcreation.typification import Typifier +from pystencils.nbackend.kernelcreation.typification import Typifier, TypificationError def test_typify_simple(): @@ -68,3 +69,29 @@ def test_contextual_typing(): pytest.fail(f"Unexpected expression: {expr}") check(expr.expression) + + +def test_erronous_typing(): + options = KernelCreationOptions(default_dtype=make_numeric_type(np.float64)) + ctx = KernelCreationContext(options) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + q = TypedSymbol("q", np.float32) + w = TypedSymbol("w", np.float16) + + expr = freeze(2 * x + 3 * y + q - 4) + + with pytest.raises(TypificationError): + typify(expr) + + asm = Assignment(q, 3 - w) + fasm = freeze(asm) + with pytest.raises(TypificationError): + typify(fasm) + + asm = Assignment(q, 3 - x) + fasm = freeze(asm) + with pytest.raises(TypificationError): + typify(fasm) -- GitLab