From 8f65a3a778db4a2b5bace272f0d84900e093df4b Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Thu, 28 Mar 2024 14:36:46 +0100
Subject: [PATCH] Fix typing of constants

---
 .../backend/kernelcreation/typification.py    | 13 ++++++++++---
 src/pystencils/backend/symbols.py             |  4 +++-
 src/pystencils/types/basic_types.py           |  2 +-
 .../kernelcreation/test_typification.py       | 19 ++++++++++++++-----
 4 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index dcfb0f548..9ef649b31 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -96,7 +96,7 @@ class TypeContext:
         Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
         called on this context.
 
-        It the expression already has a data type set, it must be equal to the inferred type.
+        If the expression already has a data type set, it must be equal to the inferred type.
         """
 
         if self._target_type is None:
@@ -126,8 +126,15 @@ class TypeContext:
                         raise TypificationError(
                             f"Can't typify constant with non-numeric type {self._target_type}"
                         )
-                    c.apply_dtype(self._target_type)
-                
+                    if c.dtype is None:
+                        c.apply_dtype(self._target_type)
+                    elif deconstify(c.dtype) != self._target_type:
+                        raise TypificationError(
+                            f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
+                            f"  Constant type: {c.dtype}\n"
+                            f"    Target type: {self._target_type}"
+                        )
+
                 case PsSymbolExpr(symb):
                     symb.apply_dtype(self._target_type)
 
diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py
index e1db5a930..b007e3fcf 100644
--- a/src/pystencils/backend/symbols.py
+++ b/src/pystencils/backend/symbols.py
@@ -42,7 +42,9 @@ class PsSymbol:
 
     def get_dtype(self) -> PsType:
         if self._dtype is None:
-            raise PsInternalCompilerError("Symbol had no type assigned yet")
+            raise PsInternalCompilerError(
+                f"Symbol {self.name} had no type assigned yet"
+            )
         return self._dtype
 
     def __str__(self) -> str:
diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py
index 7565ea9a5..b83b6d7d6 100644
--- a/src/pystencils/types/basic_types.py
+++ b/src/pystencils/types/basic_types.py
@@ -655,7 +655,7 @@ class PsIeeeFloatType(PsScalarType):
     def __init__(self, width: int, const: bool = False):
         if width not in self.SUPPORTED_WIDTHS:
             raise ValueError(
-                f"Invalid integer width; must be one of {self.SUPPORTED_WIDTHS}."
+                f"Invalid integer width {width}; must be one of {self.SUPPORTED_WIDTHS}."
             )
 
         super().__init__(const)
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 9ff18623e..cb7e5561f 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -131,13 +131,11 @@ def test_typify_integer_binops():
 
     ctx.get_symbol("x", ctx.index_dtype)
     ctx.get_symbol("y", ctx.index_dtype)
-    ctx.get_symbol("z", ctx.index_dtype)
 
-    x, y, z = sp.symbols("x, y, z")
+    x, y = sp.symbols("x, y")
     expr = bit_shift_left(
-        bit_shift_right(bitwise_and(x, 2), bitwise_or(y, z)), bitwise_xor(2, 2)
-    )  #                            ^
-    # TODO: x can not be a constant here, because then the typifier can not check that the arguments are integer.
+        bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
+    )
     expr = freeze(expr)
     expr = typify(expr)
 
@@ -184,3 +182,14 @@ def test_typify_integer_binops_in_floating_context():
 
     with pytest.raises(TypificationError):
         expr = typify(expr)
+
+
+def test_regression_typify_constants():
+    ctx = KernelCreationContext(default_dtype=Fp(32))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y = sp.symbols("x, y")
+    expr = (-x - y) ** 2
+
+    typify(freeze(expr))  # just test that no error is raised
-- 
GitLab