From 953a1f9387f40e38058ee60ebba08b74cece919c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 28 Mar 2024 16:03:35 +0100
Subject: [PATCH] Various fixes to constants

---
 src/pystencils/backend/constants.py           | 50 +++++++++---
 .../backend/kernelcreation/typification.py    |  2 +-
 src/pystencils/types/basic_types.py           | 48 +++++------
 .../kernelcreation/test_typification.py       | 19 +++--
 tests/nbackend/test_constant_folding.py       | 26 ------
 tests/nbackend/test_constants.py              | 79 ++++++++++++++++++
 tests/nbackend/types/test_constants.py        | 80 -------------------
 7 files changed, 154 insertions(+), 150 deletions(-)
 delete mode 100644 tests/nbackend/test_constant_folding.py
 create mode 100644 tests/nbackend/test_constants.py
 delete mode 100644 tests/nbackend/types/test_constants.py

diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py
index 6e76f6dbb..6dc07842f 100644
--- a/src/pystencils/backend/constants.py
+++ b/src/pystencils/backend/constants.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import Any
 
 from ..types import PsNumericType, constify
@@ -5,6 +6,21 @@ from .exceptions import PsInternalCompilerError
 
 
 class PsConstant:
+    """Type-safe representation of typed numerical constants.
+    
+    This class models constants in the backend representation of kernels.
+    A constant may be *untyped*, in which case its ``value`` may be any Python object.
+    
+    If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used
+    to check the validity of its ``value`` and to convert it into the type's internal representation.
+
+    Instances of `PsConstant` are immutable.
+
+    Args:
+        value: The constant's value
+        dtype: The constant's data type, or ``None`` if untyped.
+    """
+
     __match_args__ = ("value", "dtype")
 
     def __init__(self, value: Any, dtype: PsNumericType | None = None):
@@ -12,7 +28,30 @@ class PsConstant:
         self._value = value
 
         if dtype is not None:
-            self.apply_dtype(dtype)
+            self._dtype = constify(dtype)
+            self._value = self._dtype.create_constant(self._value)
+        else:
+            self._dtype = None
+            self._value = value
+
+    def interpret_as(self, dtype: PsNumericType) -> PsConstant:
+        """Interprets this *untyped* constant with the given data type.
+        
+        If this constant is already typed, raises an error.
+        """
+        if self._dtype is not None:
+            raise PsInternalCompilerError(
+                f"Cannot interpret already typed constant {self} with type {dtype}"
+            )
+        
+        return PsConstant(self._value, dtype)
+    
+    def reinterpret_as(self, dtype: PsNumericType) -> PsConstant:
+        """Reinterprets this constant with the given data type.
+        
+        Other than `interpret_as`, this method also works on typed constants.
+        """
+        return PsConstant(self._value, dtype)
 
     @property
     def value(self) -> Any:
@@ -27,15 +66,6 @@ class PsConstant:
             raise PsInternalCompilerError("Data type of constant was not set.")
         return self._dtype
 
-    def apply_dtype(self, dtype: PsNumericType):
-        if self._dtype is not None:
-            raise PsInternalCompilerError(
-                "Attempt to apply data type to already typed constant."
-            )
-
-        self._dtype = constify(dtype)
-        self._value = self._dtype.create_constant(self._value)
-
     def __str__(self) -> str:
         type_str = "<untyped>" if self._dtype is None else str(self._dtype)
         return f"{str(self._value)}: {type_str}"
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 9ef649b31..d2c93e221 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -127,7 +127,7 @@ class TypeContext:
                             f"Can't typify constant with non-numeric type {self._target_type}"
                         )
                     if c.dtype is None:
-                        c.apply_dtype(self._target_type)
+                        expr.constant = c.interpret_as(self._target_type)
                     elif deconstify(c.dtype) != self._target_type:
                         raise TypificationError(
                             f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py
index b83b6d7d6..3678ea126 100644
--- a/src/pystencils/types/basic_types.py
+++ b/src/pystencils/types/basic_types.py
@@ -486,9 +486,12 @@ class PsBoolType(PsScalarType):
         return np.dtype(PsBoolType.NUMPY_TYPE)
 
     def create_literal(self, value: Any) -> str:
-        if value in (1, True, np.True_):
+        if not isinstance(value, self.NUMPY_TYPE):
+            raise PsTypeError(f"Given value {value} is not of required type {self.NUMPY_TYPE}")
+
+        if value == np.True_:
             return "true"
-        elif value in (0, False, np.False_):
+        elif value == np.False_:
             return "false"
         else:
             raise PsTypeError(f"Cannot create boolean literal from {value}")
@@ -560,6 +563,17 @@ class PsIntegerType(PsScalarType, ABC):
         unsigned_suffix = "" if self.signed else "u"
         #   TODO: cast literal to correct type?
         return str(value) + unsigned_suffix
+    
+    def create_constant(self, value: Any) -> Any:
+        np_type = self.NUMPY_TYPES[self._width]
+
+        if isinstance(value, (int, np.integer)):
+            iinfo = np.iinfo(np_type)  # type: ignore
+            if value < iinfo.min or value > iinfo.max:
+                raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
+            return np_type(value)
+
+        raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsIntegerType):
@@ -598,17 +612,6 @@ class PsSignedIntegerType(PsIntegerType):
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, True, const)
 
-    def create_constant(self, value: Any) -> Any:
-        np_type = self.NUMPY_TYPES[self._width]
-
-        if isinstance(value, int):
-            return np_type(value)
-
-        if isinstance(value, np_type):
-            return value
-
-        raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
-
 
 @final
 class PsUnsignedIntegerType(PsIntegerType):
@@ -626,17 +629,6 @@ class PsUnsignedIntegerType(PsIntegerType):
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, False, const)
 
-    def create_constant(self, value: Any) -> Any:
-        np_type = self.NUMPY_TYPES[self._width]
-
-        if isinstance(value, int) and value >= 0:
-            return np_type(value)
-
-        if isinstance(value, np_type):
-            return value
-
-        raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
-
 
 @final
 class PsIeeeFloatType(PsScalarType):
@@ -698,12 +690,12 @@ class PsIeeeFloatType(PsScalarType):
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
 
-        if isinstance(value, int) or isinstance(value, float):
+        if isinstance(value, (int, float, np.floating)):
+            finfo = np.finfo(np_type)  # type: ignore
+            if value < finfo.min or value > finfo.max:
+                raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
             return np_type(value)
 
-        if isinstance(value, np_type):
-            return value
-
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
 
     def __eq__(self, other: object) -> bool:
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index cb7e5561f..ef746c614 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -2,10 +2,13 @@ import pytest
 import sympy as sp
 import numpy as np
 
+from typing import cast
+
 from pystencils import Assignment, TypedSymbol, Field, FieldType
 
 from pystencils.backend.ast.structural import PsDeclaration
 from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
+from pystencils.backend.constants import PsConstant
 from pystencils.types import constify
 from pystencils.types.quick import Fp, create_numeric_type
 from pystencils.backend.kernelcreation.context import KernelCreationContext
@@ -35,6 +38,7 @@ def test_typify_simple():
     assert isinstance(fasm, PsDeclaration)
 
     def check(expr):
+        assert expr.dtype == ctx.default_dtype
         match expr:
             case PsConstantExpr(cs):
                 assert cs.value == 2
@@ -83,6 +87,7 @@ def test_contextual_typing():
     expr = typify(expr)
 
     def check(expr):
+        assert expr.dtype == ctx.default_dtype
         match expr:
             case PsConstantExpr(cs):
                 assert cs.value in (2, 3, -4)
@@ -184,12 +189,16 @@ def test_typify_integer_binops_in_floating_context():
         expr = typify(expr)
 
 
-def test_regression_typify_constants():
+def test_typify_constant_clones():
     ctx = KernelCreationContext(default_dtype=Fp(32))
-    freeze = FreezeExpressions(ctx)
     typify = Typifier(ctx)
 
-    x, y = sp.symbols("x, y")
-    expr = (-x - y) ** 2
+    c = PsConstantExpr(PsConstant(3.0))
+    x = PsSymbolExpr(ctx.get_symbol("x"))
+    expr = c + x
+    expr_clone = expr.clone()
 
-    typify(freeze(expr))  # just test that no error is raised
+    expr = typify(expr)
+    
+    assert expr_clone.operand1.dtype is None
+    assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
diff --git a/tests/nbackend/test_constant_folding.py b/tests/nbackend/test_constant_folding.py
deleted file mode 100644
index ee214ff53..000000000
--- a/tests/nbackend/test_constant_folding.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#   TODO: Reimplement for constant folder
-# import pytest
-
-# from pystencils.types.quick import *
-# from pystencils.backend.constants import PsConstant
-
-
-# @pytest.mark.parametrize("width", (8, 16, 32, 64))
-# def test_constant_folding_int(width):
-#     folder = ConstantFoldingMapper()
-
-#     expr = pb.Sum(
-#         (
-#             PsTypedConstant(13, UInt(width)),
-#             PsTypedConstant(5, UInt(width)),
-#             PsTypedConstant(3, UInt(width)),
-#         )
-#     )
-
-#     assert folder(expr) == PsTypedConstant(21, UInt(width))
-
-#     expr = pb.Product(
-#         (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width)))
-#     ) - PsTypedConstant(12, SInt(width))
-
-#     assert folder(expr) == PsTypedConstant(-53, SInt(width))
diff --git a/tests/nbackend/test_constants.py b/tests/nbackend/test_constants.py
new file mode 100644
index 000000000..93c772e60
--- /dev/null
+++ b/tests/nbackend/test_constants.py
@@ -0,0 +1,79 @@
+import numpy as np
+import pytest
+
+from pystencils.types import PsTypeError
+from pystencils.backend.constants import PsConstant
+from pystencils.types.quick import Fp, Bool, UInt, SInt
+from pystencils.backend.exceptions import PsInternalCompilerError
+
+
+def test_constant_equality():
+    c1 = PsConstant(1.0, Fp(32))
+    c2 = PsConstant(1.0, Fp(32))
+
+    assert c1 == c2
+    assert hash(c1) == hash(c2)
+
+    c3 = PsConstant(1.0, Fp(64))
+    assert c1 != c3
+    assert hash(c1) != hash(c3)
+
+    c4 = c1.reinterpret_as(Fp(64))
+    assert c4 != c1
+    assert c4 == c3
+
+
+def test_interpret():
+    c1 = PsConstant(3.4, Fp(32))
+    c2 = PsConstant(3.4)
+
+    assert c2.interpret_as(Fp(32)) == c1
+
+    with pytest.raises(PsInternalCompilerError):
+        _ = c1.interpret_as(Fp(64))
+
+
+def test_boolean_constants():
+    true = PsConstant(True, Bool())
+    for val in (1, 1.0, True, np.True_):
+        assert PsConstant(val, Bool()) == true
+
+    false = PsConstant(False, Bool())
+    for val in (0, 0.0, False, np.False_):
+        assert PsConstant(val, Bool()) == false
+
+    with pytest.raises(PsTypeError):
+        PsConstant(1.1, Bool())
+
+
+def test_integer_bounds():
+    #  should not throw:
+    for val in (255, np.uint8(255), np.int16(255), np.int64(255)):
+        _ = PsConstant(val, UInt(8))
+
+    for val in (-128, np.int16(-128), np.int64(-128)):
+        _ = PsConstant(val, SInt(8))
+    
+    #  should throw:
+    for val in (256, np.int16(256), np.int64(256)):
+        with pytest.raises(PsTypeError):
+            _ = PsConstant(val, UInt(8))
+
+    for val in (-42, np.int32(-42)):
+        with pytest.raises(PsTypeError):
+            _ = PsConstant(val, UInt(8))
+
+    for val in (-129, np.int16(-129), np.int64(-129)):
+        with pytest.raises(PsTypeError):
+            _ = PsConstant(val, SInt(8))
+
+
+def test_floating_bounds():
+    for val in (5.1e4, -5.9e4):
+        _ = PsConstant(val, Fp(16))
+        _ = PsConstant(val, Fp(32))
+        _ = PsConstant(val, Fp(64))
+
+    for val in (8.1e5, -7.6e5):
+        with pytest.raises(PsTypeError):
+            _ = PsConstant(val, Fp(16))
diff --git a/tests/nbackend/types/test_constants.py b/tests/nbackend/types/test_constants.py
deleted file mode 100644
index 4d948e4e3..000000000
--- a/tests/nbackend/types/test_constants.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# import pytest
-
-# TODO: Re-implement for constant folder
-# from pystencils.types.quick import *
-# from pystencils.types import PsTypeError
-# from pystencils.backend.typed_expressions import PsTypedConstant
-
-
-# @pytest.mark.parametrize("width", (8, 16, 32, 64))
-# def test_integer_constants(width):
-#     dtype = SInt(width)
-#     a = PsTypedConstant(42, dtype)
-#     b = PsTypedConstant(2, dtype)
-
-#     assert a + b == PsTypedConstant(44, dtype)
-#     assert a - b == PsTypedConstant(40, dtype)
-#     assert a * b == PsTypedConstant(84, dtype)
-
-#     assert a - b != PsTypedConstant(-12, dtype)
-
-#     #   Typed constants only compare to themselves
-#     assert a + b != 44
-
-
-# @pytest.mark.parametrize("width", (32, 64))
-# def test_float_constants(width):
-#     a = PsTypedConstant(32.0, Fp(width))
-#     b = PsTypedConstant(0.5, Fp(width))
-#     c = PsTypedConstant(2.0, Fp(width))
-
-#     assert a + b == PsTypedConstant(32.5, Fp(width))
-#     assert a * b == PsTypedConstant(16.0, Fp(width))
-#     assert a - b == PsTypedConstant(31.5, Fp(width))
-#     assert a / c == PsTypedConstant(16.0, Fp(width))
-
-
-# def test_illegal_ops():
-#     #   Cannot interpret negative numbers as unsigned types
-#     with pytest.raises(PsTypeError):
-#         _ = PsTypedConstant(-3, UInt(32))
-
-#     #   Mixed ops are illegal
-#     with pytest.raises(PsTypeError):
-#         _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32))
-
-#     with pytest.raises(PsTypeError):
-#         _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32))
-
-#     with pytest.raises(PsTypeError):
-#         _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32))
-
-#     with pytest.raises(PsTypeError):
-#         _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32))
-
-
-# @pytest.mark.parametrize("width", (8, 16, 32, 64))
-# def test_unsigned_integer_division(width):
-#     a = PsTypedConstant(8, UInt(width))
-#     b = PsTypedConstant(3, UInt(width))
-
-#     assert a / b == PsTypedConstant(2, UInt(width))
-#     assert a % b == PsTypedConstant(2, UInt(width))
-
-
-# @pytest.mark.parametrize("width", (8, 16, 32, 64))
-# def test_signed_integer_division(width):
-#     five = PsTypedConstant(5, SInt(width))
-#     two = PsTypedConstant(2, SInt(width))
-
-#     assert five / two == PsTypedConstant(2, SInt(width))
-#     assert five % two == PsTypedConstant(1, SInt(width))
-
-#     assert (- five) / two == PsTypedConstant(-2, SInt(width))
-#     assert (- five) % two == PsTypedConstant(-1, SInt(width))
-
-#     assert five / (- two) == PsTypedConstant(-2, SInt(width))
-#     assert five % (- two) == PsTypedConstant(1, SInt(width))
-
-#     assert (- five) / (- two) == PsTypedConstant(2, SInt(width))
-#     assert (- five) % (- two) == PsTypedConstant(-1, SInt(width))
-- 
GitLab