From 095436ed00d326c63298ede52ffbc0474b0c7cc8 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 31 Jan 2024 14:09:40 +0100
Subject: [PATCH] additional tests & some fixes

---
 .../nbackend/kernelcreation/freeze.py         |  8 +--
 .../nbackend/kernelcreation/typification.py   |  5 +-
 src/pystencils/nbackend/types/__init__.py     |  3 +-
 src/pystencils/nbackend/types/basic_types.py  |  9 +--
 src/pystencils/nbackend/types/parsing.py      |  2 +
 src/pystencils/nbackend/types/quick.py        | 17 +++--
 tests/nbackend/test_cpujit.py                 |  4 +-
 tests/nbackend/test_types.py                  | 72 +++++++++++++++++++
 tests/nbackend/test_typification.py           | 33 ++++++++-
 9 files changed, 132 insertions(+), 21 deletions(-)
 create mode 100644 tests/nbackend/test_types.py

diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py
index 11ea24219..22eb1dbf4 100644
--- a/src/pystencils/nbackend/kernelcreation/freeze.py
+++ b/src/pystencils/nbackend/kernelcreation/freeze.py
@@ -32,15 +32,15 @@ class FreezeExpressions(SympyToPymbolicMapper):
 
     @overload
     def __call__(self, asms: AssignmentCollection) -> PsBlock:
-        ...
+        pass
 
     @overload
     def __call__(self, expr: sp.Expr) -> PsExpression:
-        ...
+        pass
 
     @overload
-    def __call__(self, expr: Assignment) -> PsAssignment:
-        ...
+    def __call__(self, asm: Assignment) -> PsAssignment:
+        pass
 
     def __call__(self, obj):
         if isinstance(obj, AssignmentCollection):
diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py
index c8421fbfe..f1d299e28 100644
--- a/src/pystencils/nbackend/kernelcreation/typification.py
+++ b/src/pystencils/nbackend/kernelcreation/typification.py
@@ -171,8 +171,9 @@ class Typifier(Mapper):
 
     def typify_expression(
         self, expr: Any, target_type: PsNumericType | None = None
-    ) -> ExprOrConstant:
-        return self.rec(expr, TypeContext(target_type))
+    ) -> tuple[ExprOrConstant, PsNumericType]:
+        tc = TypeContext(target_type)
+        return self.rec(expr, tc)
 
     #   Leaf nodes: Variables, Typed Variables, Constants and TypedConstants
 
diff --git a/src/pystencils/nbackend/types/__init__.py b/src/pystencils/nbackend/types/__init__.py
index 13deab6b4..d7eb490c5 100644
--- a/src/pystencils/nbackend/types/__init__.py
+++ b/src/pystencils/nbackend/types/__init__.py
@@ -13,7 +13,7 @@ from .basic_types import (
     deconstify,
 )
 
-from .quick import make_type
+from .quick import make_type, make_numeric_type
 
 from .exception import PsTypeError
 
@@ -31,5 +31,6 @@ __all__ = [
     "constify",
     "deconstify",
     "make_type",
+    "make_numeric_type",
     "PsTypeError",
 ]
diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py
index 540f28334..bb27e3493 100644
--- a/src/pystencils/nbackend/types/basic_types.py
+++ b/src/pystencils/nbackend/types/basic_types.py
@@ -208,10 +208,9 @@ class PsStructType(PsAbstractType):
 
     def _c_string(self) -> str:
         if self._name is None:
-            # raise PsInternalCompilerError(
-            #     "Cannot retrieve C string for anonymous struct type"
-            # )
-            return "<anonymous>"
+            raise PsInternalCompilerError(
+                "Cannot retrieve C string for anonymous struct type"
+            )
         return self._name
 
     def __eq__(self, other: object) -> bool:
@@ -502,6 +501,8 @@ class PsIeeeFloatType(PsScalarType):
 
     def _c_string(self) -> str:
         match self._width:
+            case 16:
+                return f"{self._const_string()}half"
             case 32:
                 return f"{self._const_string()}float"
             case 64:
diff --git a/src/pystencils/nbackend/types/parsing.py b/src/pystencils/nbackend/types/parsing.py
index 952438f11..be9600c71 100644
--- a/src/pystencils/nbackend/types/parsing.py
+++ b/src/pystencils/nbackend/types/parsing.py
@@ -34,6 +34,8 @@ def interpret_python_type(t: type) -> PsAbstractType:
     if t is np.int64:
         return PsSignedIntegerType(64)
 
+    if t is np.float16:
+        return PsIeeeFloatType(16)
     if t is np.float32:
         return PsIeeeFloatType(32)
     if t is np.float64:
diff --git a/src/pystencils/nbackend/types/quick.py b/src/pystencils/nbackend/types/quick.py
index cf65897d7..e5d271cf9 100644
--- a/src/pystencils/nbackend/types/quick.py
+++ b/src/pystencils/nbackend/types/quick.py
@@ -11,6 +11,7 @@ import numpy as np
 from .basic_types import (
     PsAbstractType,
     PsCustomType,
+    PsNumericType,
     PsScalarType,
     PsPointerType,
     PsIntegerType,
@@ -39,11 +40,7 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
         - Instances of `PsAbstractType` will be returned as they are
     """
 
-    from .parsing import (
-        parse_type_string,
-        interpret_python_type,
-        interpret_numpy_dtype
-    )
+    from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype
 
     if isinstance(type_spec, PsAbstractType):
         return type_spec
@@ -56,6 +53,16 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
     raise ValueError(f"{type_spec} is not a valid type specification.")
 
 
+def make_numeric_type(type_spec: UserTypeSpec) -> PsNumericType:
+    """Like `make_type`, but only for numeric types."""
+    dtype = make_type(type_spec)
+    if not isinstance(dtype, PsNumericType):
+        raise ValueError(
+            f"Given type {type_spec} does not translate to a numeric type."
+        )
+    return dtype
+
+
 Custom = PsCustomType
 """`Custom(name)` matches `PsCustomType(name)`"""
 
diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py
index 6c2a453c7..b93f7a1e2 100644
--- a/tests/nbackend/test_cpujit.py
+++ b/tests/nbackend/test_cpujit.py
@@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load
 def test_pairwise_addition():
     idx_type = SInt(64)
 
-    u = PsLinearizedArray("u", Fp(64, const=True), (..., ...), (..., ...), index_dtype=idx_type)
-    v = PsLinearizedArray("v", Fp(64), (..., ...), (..., ...), index_dtype=idx_type)
+    u = PsLinearizedArray("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type)
+    v = PsLinearizedArray("v", Fp(64), (...,), (...,), index_dtype=idx_type)
 
     u_data = PsArrayBasePointer("u_data", u)
     v_data = PsArrayBasePointer("v_data", v)
diff --git a/tests/nbackend/test_types.py b/tests/nbackend/test_types.py
new file mode 100644
index 000000000..ba5746222
--- /dev/null
+++ b/tests/nbackend/test_types.py
@@ -0,0 +1,72 @@
+import pytest
+import numpy as np
+
+from pystencils.nbackend.exceptions import PsInternalCompilerError
+from pystencils.nbackend.types import *
+from pystencils.nbackend.types.quick import *
+
+
+@pytest.mark.parametrize(
+    "numpy_type",
+    [
+        np.uint8,
+        np.uint16,
+        np.uint32,
+        np.uint64,
+        np.int8,
+        np.int16,
+        np.int32,
+        np.int64,
+        np.float16,
+        np.float32,
+        np.float64,
+    ],
+)
+def test_numpy_translation(numpy_type):
+    dtype_obj = np.dtype(numpy_type)
+    ps_type = make_type(numpy_type)
+
+    assert isinstance(ps_type, PsNumericType)
+    assert ps_type.numpy_dtype == dtype_obj
+    assert ps_type.itemsize == dtype_obj.itemsize
+
+    assert isinstance(ps_type.create_constant(13), numpy_type)
+
+    if ps_type.is_int():
+        with pytest.raises(PsTypeError):
+            ps_type.create_constant(13.0)
+        with pytest.raises(PsTypeError):
+            ps_type.create_constant(1.75)
+
+    if ps_type.is_sint():
+        assert numpy_type(17) == ps_type.create_constant(17)
+        assert numpy_type(-4) == ps_type.create_constant(-4)
+
+    if ps_type.is_uint():
+        with pytest.raises(PsTypeError):
+            ps_type.create_constant(-4)
+
+    if ps_type.is_float():
+        assert numpy_type(17.3) == ps_type.create_constant(17.3)
+        assert numpy_type(-4.2) == ps_type.create_constant(-4.2)
+
+
+def test_constify():
+    t = PsCustomType("std::shared_ptr< Custom >")
+    assert deconstify(t) == t
+    assert deconstify(constify(t)) == t
+    s = PsCustomType("Field", const=True)
+    assert constify(s) == s
+
+
+def test_struct_types():
+    t = PsStructType(
+        [
+            PsStructType.Member("data", Ptr(Fp(32))),
+            ("size", UInt(32)),
+        ]
+    )
+
+    assert t.anonymous
+    with pytest.raises(PsInternalCompilerError):
+        str(t)
diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/test_typification.py
index 6caadb084..ae477fe19 100644
--- a/tests/nbackend/test_typification.py
+++ b/tests/nbackend/test_typification.py
@@ -1,16 +1,17 @@
 import pytest
 import sympy as sp
+import numpy as np
 import pymbolic.primitives as pb
 
-from pystencils import Assignment
+from pystencils import Assignment, TypedSymbol
 
 from pystencils.nbackend.ast import PsDeclaration
-from pystencils.nbackend.types import constify
+from pystencils.nbackend.types import constify, make_numeric_type
 from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
 from pystencils.nbackend.kernelcreation.options import KernelCreationOptions
 from pystencils.nbackend.kernelcreation.context import KernelCreationContext
 from pystencils.nbackend.kernelcreation.freeze import FreezeExpressions
-from pystencils.nbackend.kernelcreation.typification import Typifier
+from pystencils.nbackend.kernelcreation.typification import Typifier, TypificationError
 
 
 def test_typify_simple():
@@ -68,3 +69,29 @@ def test_contextual_typing():
                 pytest.fail(f"Unexpected expression: {expr}")
 
     check(expr.expression)
+
+
+def test_erronous_typing():
+    options = KernelCreationOptions(default_dtype=make_numeric_type(np.float64))
+    ctx = KernelCreationContext(options)
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    q = TypedSymbol("q", np.float32)
+    w = TypedSymbol("w", np.float16)
+
+    expr = freeze(2 * x + 3 * y + q - 4)
+
+    with pytest.raises(TypificationError):
+        typify(expr)
+
+    asm = Assignment(q, 3 - w)
+    fasm = freeze(asm)
+    with pytest.raises(TypificationError):
+        typify(fasm)
+
+    asm = Assignment(q, 3 - x)
+    fasm = freeze(asm)
+    with pytest.raises(TypificationError):
+        typify(fasm)
-- 
GitLab