From 62e6ddf8b6bd769c2c6088a10b80e55bf4d174d7 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 11 Jan 2024 16:56:47 +0100
Subject: [PATCH] PsTypedConstant: Creation and Arithmetic, Part One

---
 pystencils/nbackend/ast/nodes.py              |  41 +++--
 pystencils/nbackend/typed_expressions.py      | 165 ++++++++++++++----
 pystencils/nbackend/types/__init__.py         |  11 +-
 pystencils/nbackend/types/basic_types.py      | 136 ++++++++++++++-
 pystencils/nbackend/types/exception.py        |   4 +
 pystencils/nbackend/types/parsing.py          |   3 -
 pystencils/nbackend/types/quick.py            |   4 +-
 .../nbackend/test_constant_folding.py         |  45 +++++
 .../nbackend/types/test_constants.py          |  61 +++++++
 9 files changed, 405 insertions(+), 65 deletions(-)
 create mode 100644 pystencils/nbackend/types/exception.py
 create mode 100644 pystencils_tests/nbackend/test_constant_folding.py
 create mode 100644 pystencils_tests/nbackend/types/test_constants.py

diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py
index adb3e2a01..40200ff38 100644
--- a/pystencils/nbackend/ast/nodes.py
+++ b/pystencils/nbackend/ast/nodes.py
@@ -19,7 +19,7 @@ def failing_cast(target: type, obj: T):
 
 class PsAstNode(ABC):
     """Base class for all nodes in the pystencils AST.
-    
+
     This base class provides a common interface to inspect and update the AST's branching structure.
     The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by
     each subclass.
@@ -150,7 +150,7 @@ class PsAssignment(PsAstNode):
         return (self._lhs, self._rhs)[idx]
 
     def set_child(self, idx: int, c: PsAstNode):
-        idx = [0, 1][idx]   # trick to normalize index
+        idx = [0, 1][idx]  # trick to normalize index
         if idx == 0:
             self._lhs = failing_cast(PsLvalueExpr, c)
         elif idx == 1:
@@ -180,7 +180,7 @@ class PsDeclaration(PsAssignment):
         self._lhs = lvalue
 
     def set_child(self, idx: int, c: PsAstNode):
-        idx = [0, 1][idx]   # trick to normalize index
+        idx = [0, 1][idx]  # trick to normalize index
         if idx == 0:
             self._lhs = failing_cast(PsSymbolExpr, c)
         elif idx == 1:
@@ -190,12 +190,14 @@ class PsDeclaration(PsAssignment):
 
 
 class PsLoop(PsAstNode):
-    def __init__(self,
-                 ctr: PsSymbolExpr,
-                 start: PsExpression,
-                 stop: PsExpression,
-                 step: PsExpression,
-                 body: PsBlock):
+    def __init__(
+        self,
+        ctr: PsSymbolExpr,
+        start: PsExpression,
+        stop: PsExpression,
+        step: PsExpression,
+        body: PsBlock,
+    ):
         self._ctr = ctr
         self._start = start
         self._stop = stop
@@ -205,7 +207,7 @@ class PsLoop(PsAstNode):
     @property
     def counter(self) -> PsSymbolExpr:
         return self._ctr
-    
+
     @counter.setter
     def counter(self, expr: PsSymbolExpr):
         self._ctr = expr
@@ -251,13 +253,18 @@ class PsLoop(PsAstNode):
     def get_child(self, idx: int):
         return (self._ctr, self._start, self._stop, self._step, self._body)[idx]
 
-
     def set_child(self, idx: int, c: PsAstNode):
         idx = list(range(5))[idx]
         match idx:
-            case 0: self._ctr = failing_cast(PsSymbolExpr, c)
-            case 1: self._start = failing_cast(PsExpression, c)
-            case 2: self._stop = failing_cast(PsExpression, c)
-            case 3: self._step = failing_cast(PsExpression, c)
-            case 4: self._body = failing_cast(PsBlock, c)
-            case _: assert False, "unreachable code"
+            case 0:
+                self._ctr = failing_cast(PsSymbolExpr, c)
+            case 1:
+                self._start = failing_cast(PsExpression, c)
+            case 2:
+                self._stop = failing_cast(PsExpression, c)
+            case 3:
+                self._step = failing_cast(PsExpression, c)
+            case 4:
+                self._body = failing_cast(PsBlock, c)
+            case _:
+                assert False, "unreachable code"
diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py
index 5d0c0fc89..c3e4488e8 100644
--- a/pystencils/nbackend/typed_expressions.py
+++ b/pystencils/nbackend/typed_expressions.py
@@ -1,11 +1,19 @@
 from __future__ import annotations
 
 from functools import reduce
-from typing import TypeAlias, Union, Any, Tuple
+from typing import TypeAlias, Union, Any, Tuple, Callable
+import operator
 
 import pymbolic.primitives as pb
 
-from .types import PsAbstractType, PsScalarType, PsPointerType, constify
+from .types import (
+    PsAbstractType,
+    PsScalarType,
+    PsNumericType,
+    PsPointerType,
+    constify,
+    PsTypeError,
+)
 
 
 class PsTypedVariable(pb.Variable):
@@ -96,7 +104,7 @@ class PsArrayAccess(pb.Subscript):
     @property
     def array(self) -> PsArray:
         return self._base_ptr.array
-    
+
     @property
     def dtype(self) -> PsAbstractType:
         """Data type of this expression, i.e. the element type of the underlying array"""
@@ -107,50 +115,145 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
 
 
 class PsTypedConstant:
+    """Represents typed constants occuring in the pystencils AST.
+
+    A Note On Divisions
+    -------------------
+
+    The semantics of divisions in C and Python differ greatly.
+    Python has two division operators: `/` (`truediv`) and `//` (`floordiv`).
+    `truediv` is pure floating-point division, and so applied to floating-point numbers maps exactly to
+    floating-point division in C, but not when applied to integers.
+    `floordiv` has no C equivalent:
+    While `floordiv` performs euclidean division and always rounds its result
+    downward (`3 // 2 == 1`, and `-3 // 2 = -2`),
+    the C `/` operator on integers always rounds toward zero (in C, `-3 / 2 = -1`.)
+
+    The same applies to the `%` operator:
+    In Python, `%` computes the euclidean modulus (e.g. `-3 % 2 = 1`),
+    while in C, `%` computes the remainder (e.g. `-3 % 2 = -1`).
+    These two differ whenever negative numbers are involved.
+
+    Pymbolic provides `Quotient` to model Python's `/`, `FloorDiv` to model `//`, and `Remainder` to model `%`.
+    The last one is a misnomer: it should instead be called `Modulus`.
+
+    Since the pystencils backend has to accurately capture the behaviour of C,
+    the behaviour of `/` is overridden in `PsTypedConstant`.
+    In a floating-point context, it behaves as usual, while in an integer context,
+    it implements C-style integer division.
+    Similarily, `%` is only legal in integer contexts, where it implements the C-style remainder.
+    Usage of `//` and the pymbolic `FloorDiv` is illegal.
+    """
+
     @staticmethod
-    def _cast(value, target_dtype: PsAbstractType):
-        if isinstance(value, PsTypedConstant):
-            if value._dtype != target_dtype:
-                raise ValueError(
-                    f"Incompatible types: {value._dtype} and {target_dtype}"
-                )
-            return value
-
-        # TODO check legality
-        return PsTypedConstant(value, target_dtype)
-
-    def __init__(self, value, dtype: PsAbstractType):
-        """Represents typed constants occuring in the pystencils AST"""
-        if isinstance(dtype, PsScalarType):
-            dtype = constify(dtype)
-            self._value = value # todo: cast to given type
-        else:
+    def try_create(value: Any, dtype: PsNumericType):
+        try:
+            return PsTypedConstant(value, dtype)
+        except PsTypeError:
+            return None
+
+    def __init__(self, value: Any, dtype: PsNumericType):
+        """Create a new `PsTypedConstant`.
+
+        The constructor of `PsTypedConstant` will first convert the given `dtype` to its const version.
+        The given `value` will then be interpreted as that data type. The constructor will fail with an
+        exception if that is not possible.
+        """
+        if not isinstance(dtype, PsNumericType):
             raise ValueError(f"Cannot create constant of type {dtype}")
 
-        self._dtype = dtype
+        self._dtype = constify(dtype)
+        self._value = self._dtype.create_constant(value)
 
     def __str__(self) -> str:
         return str(self._value)
 
+    def __repr__(self) -> str:
+        return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
+
+    @staticmethod
+    def _fix(v: Any, dtype: PsNumericType) -> PsTypedConstant:
+        if not isinstance(v, PsTypedConstant):
+            return PsTypedConstant(v, dtype)
+        else:
+            return v
+
+    @staticmethod
+    def _bin_op(
+        lhs: PsTypedConstant, rhs: PsTypedConstant, op: Callable[[Any, Any], Any]
+    ) -> PsTypedConstant:
+        """Backend for binary operators. Never call directly!"""
+
+        if lhs._dtype != rhs._dtype:
+            raise PsTypeError(
+                f"Incompatible operand types in constant folding: {lhs._dtype} and {rhs._dtype}"
+            )
+
+        try:
+            return PsTypedConstant(op(lhs._value, rhs._value), lhs._dtype)
+        except PsTypeError:
+            raise PsTypeError(
+                f"Invalid operation in constant folding: {op.__name__}( {repr(lhs)}, {repr(rhs)}  )"
+            )
+
     def __add__(self, other: Any):
-        return NotImplemented # todo
-        # other = PsTypedConstant._cast(other, self._dtype)
+        return PsTypedConstant._bin_op(
+            self, PsTypedConstant._fix(other, self._dtype), operator.add
+        )
 
-        # return PsTypedConstant(self._value + other._value, self._dtype)
+    def __radd__(self, other: Any):
+        return PsTypedConstant._bin_op(
+            PsTypedConstant._fix(other, self._dtype), self, operator.add
+        )
 
     def __mul__(self, other: Any):
-        return NotImplemented # todo
-        # other = PsTypedConstant._cast(other, self._dtype)
+        return PsTypedConstant._bin_op(
+            self, PsTypedConstant._fix(other, self._dtype), operator.mul
+        )
 
-        # return PsTypedConstant(self._value * other._value, self._dtype)
+    def __rmul__(self, other: Any):
+        return PsTypedConstant._bin_op(
+            PsTypedConstant._fix(other, self._dtype), self, operator.mul
+        )
 
     def __sub__(self, other: Any):
-        return NotImplemented # todo
-        # other = PsTypedConstant._cast(other, self._dtype)
+        return PsTypedConstant._bin_op(
+            self, PsTypedConstant._fix(other, self._dtype), operator.sub
+        )
+
+    def __rsub__(self, other: Any):
+        return PsTypedConstant._bin_op(
+            PsTypedConstant._fix(other, self._dtype), self, operator.sub
+        )
+
+    def __truediv__(self, other: Any):
+        other2 = PsTypedConstant._fix(other, self._dtype)
+        if self._dtype.is_float():
+            return PsTypedConstant._bin_op(self, other2, operator.truediv)
+        else:
+            return NotImplemented  # todo: C integer division
+
+    def __rtruediv__(self, other: Any):
+        other2 = PsTypedConstant._fix(other, self._dtype)
+        if self._dtype.is_float():
+            return PsTypedConstant._bin_op(other2, self, operator.truediv)
+        else:
+            return NotImplemented  # todo: C integer division
+
+    def __mod__(self, other: Any):
+        return NotImplemented  # todo: C integer division
+
+    def __neg__(self):
+        return PsTypedConstant(-self._value, self._dtype)
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsTypedConstant):
+            return False
 
-        # return PsTypedConstant(self._value - other._value, self._dtype)
+        return self._dtype == other._dtype and self._value == other._value
 
-    # TODO: Remaining operators
+    def __hash__(self) -> int:
+        return hash((self._value, self._dtype))
 
 
 pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,)
diff --git a/pystencils/nbackend/types/__init__.py b/pystencils/nbackend/types/__init__.py
index 5319afa33..1f15c4516 100644
--- a/pystencils/nbackend/types/__init__.py
+++ b/pystencils/nbackend/types/__init__.py
@@ -1,6 +1,7 @@
 from .basic_types import (
     PsAbstractType,
     PsCustomType,
+    PsNumericType,
     PsScalarType,
     PsPointerType,
     PsIntegerType,
@@ -8,18 +9,22 @@ from .basic_types import (
     PsSignedIntegerType,
     PsIeeeFloatType,
     constify,
-    deconstify
+    deconstify,
 )
 
+from .exception import PsTypeError
+
 __all__ = [
     "PsAbstractType",
     "PsCustomType",
-    "PsScalarType",
     "PsPointerType",
+    "PsNumericType",
+    "PsScalarType",
     "PsIntegerType",
     "PsUnsignedIntegerType",
     "PsSignedIntegerType",
     "PsIeeeFloatType",
     "constify",
     "deconstify",
-]
\ No newline at end of file
+    "PsTypeError",
+]
diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py
index dcdb300ce..0a56f2f57 100644
--- a/pystencils/nbackend/types/basic_types.py
+++ b/pystencils/nbackend/types/basic_types.py
@@ -1,8 +1,12 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import final, TypeVar
+from typing import final, TypeVar, Any
 from copy import copy
 
+import numpy as np
+
+from .exception import PsTypeError
+
 
 class PsAbstractType(ABC):
     """Base class for all pystencils types.
@@ -51,8 +55,9 @@ class PsAbstractType(ABC):
     def __str__(self) -> str:
         return self._c_string()
 
+    @abstractmethod
     def __hash__(self) -> int:
-        return hash(self._c_string())
+        ...
 
 
 class PsCustomType(PsAbstractType):
@@ -73,6 +78,9 @@ class PsCustomType(PsAbstractType):
             return False
         return self._base_equal(other) and self._name == other._name
 
+    def __hash__(self) -> int:
+        return hash(("PsCustomType", self._name, self._const))
+
     def _c_string(self) -> str:
         return f"{self._const_string()} {self._name}"
 
@@ -106,6 +114,9 @@ class PsPointerType(PsAbstractType):
             return False
         return self._base_equal(other) and self._base_type == other._base_type
 
+    def __hash__(self) -> int:
+        return hash(("PsPointerType", self._base_type, self._restrict, self._const))
+
     def _c_string(self) -> str:
         base_str = self._base_type._c_string()
         return f"{base_str} * {self._const_string()}"
@@ -114,8 +125,51 @@ class PsPointerType(PsAbstractType):
         return f"PsPointerType( {repr(self.base_type)}, const={self.const} )"
 
 
-class PsScalarType(PsAbstractType, ABC):
-    """Class to model scalar types"""
+class PsNumericType(PsAbstractType, ABC):
+    """Class to model numeric types, which are all types that may occur at the top-level inside
+    arithmetic-logical expressions.
+
+    Constants
+    ---------
+
+    Every numeric type has to act as a factory for compile-time constants of that type.
+    The `PsTypedConstant` class relies on `create_constant` to instantiate constants
+    of a given numeric type. The object returned by `create_constant` must implement the
+    necessary arithmetic operations, and its arithmetic behaviour must match the given type.
+
+    `create_constant` should fail whenever its input cannot safely be interpreted as the given
+    type. As for which interpretations are considered 'safe', it should be as restrictive as possible.
+    However, `create_constant` must *never* fail for the literals `0`, `1` and `-1`.
+    """
+
+    @abstractmethod
+    def create_constant(self, value: Any) -> Any:
+        """
+        Create the internal representation of a constant with this type.
+
+        Raises:
+            PsTypeError: If the given value cannot be interpreted in this type.
+        """
+
+    @abstractmethod
+    def is_int(self) -> bool:
+        ...
+
+    @abstractmethod
+    def is_sint(self) -> bool:
+        ...
+
+    @abstractmethod
+    def is_uint(self) -> bool:
+        ...
+
+    @abstractmethod
+    def is_float(self) -> bool:
+        ...
+
+
+class PsScalarType(PsNumericType, ABC):
+    """Class to model scalar numeric types."""
 
     def is_int(self) -> bool:
         return isinstance(self, PsIntegerType)
@@ -130,7 +184,7 @@ class PsScalarType(PsAbstractType, ABC):
         return isinstance(self, PsIeeeFloatType)
 
 
-class PsIntegerType(PsAbstractType, ABC):
+class PsIntegerType(PsScalarType, ABC):
     """Class to model signed and unsigned integer types.
 
     `PsIntegerType` cannot be instantiated on its own, but only through `PsSignedIntegerType`
@@ -170,6 +224,9 @@ class PsIntegerType(PsAbstractType, ABC):
             and self._signed == other._signed
         )
 
+    def __hash__(self) -> int:
+        return hash(("PsIntegerType", self._width, self._signed, self._const))
+
     def _c_string(self) -> str:
         prefix = "" if self._signed else "u"
         return f"{self._const_string()} {prefix}int{self._width}_t"
@@ -184,9 +241,27 @@ class PsSignedIntegerType(PsIntegerType):
 
     __match_args__ = ("width",)
 
+    NUMPY_TYPES = {
+        8: np.int8,
+        16: np.int16,
+        32: np.int32,
+        64: np.int64,
+    }
+
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, True, const)
 
+    def create_constant(self, value: Any) -> Any:
+        np_type = self.NUMPY_TYPES[self._width]
+
+        if isinstance(value, int):
+            return np_type(value)
+
+        if isinstance(value, np_type):
+            return value
+
+        raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
+
 
 @final
 class PsUnsignedIntegerType(PsIntegerType):
@@ -194,18 +269,42 @@ class PsUnsignedIntegerType(PsIntegerType):
 
     __match_args__ = ("width",)
 
+    NUMPY_TYPES = {
+        8: np.uint8,
+        16: np.uint16,
+        32: np.uint32,
+        64: np.uint64,
+    }
+
     def __init__(self, width: int, const: bool = False):
-        super().__init__(width, True, const)
+        super().__init__(width, False, const)
+
+    def create_constant(self, value: Any) -> Any:
+        np_type = self.NUMPY_TYPES[self._width]
+
+        if isinstance(value, int) and value >= 0:
+            return np_type(value)
+
+        if isinstance(value, np_type):
+            return value
+
+        raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
 
 
 @final
-class PsIeeeFloatType(PsAbstractType):
+class PsIeeeFloatType(PsScalarType):
     """Class to model IEEE-754 floating point data types"""
 
     __match_args__ = ("width",)
 
     SUPPORTED_WIDTHS = (16, 32, 64)
 
+    NUMPY_TYPES = {
+        16: np.float16,
+        32: np.float32,
+        64: np.float64,
+    }
+
     def __init__(self, width: int, const: bool = False):
         if width not in self.SUPPORTED_WIDTHS:
             raise ValueError(
@@ -219,11 +318,28 @@ class PsIeeeFloatType(PsAbstractType):
     def width(self) -> int:
         return self._width
 
+    def create_constant(self, value: Any) -> Any:
+        np_type = self.NUMPY_TYPES[self._width]
+
+        if isinstance(value, int) and value in (0, 1, -1):
+            return np_type(value)
+
+        if isinstance(value, float):
+            return np_type(value)
+
+        if isinstance(value, np_type):
+            return value
+
+        raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsIeeeFloatType):
             return False
         return self._base_equal(other) and self._width == other._width
 
+    def __hash__(self) -> int:
+        return hash(("PsIeeeFloatType", self._width, self._const))
+
     def _c_string(self) -> str:
         match self._width:
             case 32:
@@ -239,13 +355,15 @@ class PsIeeeFloatType(PsAbstractType):
 
 T = TypeVar("T", bound=PsAbstractType)
 
-def constify(t: T):
+
+def constify(t: T) -> T:
     """Adds the const qualifier to a given type."""
     t_copy = copy(t)
     t_copy._const = True
     return t_copy
 
-def deconstify(t: T):
+
+def deconstify(t: T) -> T:
     """Removes the const qualifier from a given type."""
     t_copy = copy(t)
     t_copy._const = False
diff --git a/pystencils/nbackend/types/exception.py b/pystencils/nbackend/types/exception.py
new file mode 100644
index 000000000..7c0cb97af
--- /dev/null
+++ b/pystencils/nbackend/types/exception.py
@@ -0,0 +1,4 @@
+
+
+class PsTypeError(Exception):
+    """Indicates a type error in the pystencils AST."""
diff --git a/pystencils/nbackend/types/parsing.py b/pystencils/nbackend/types/parsing.py
index a6f2ac42f..8a5e687aa 100644
--- a/pystencils/nbackend/types/parsing.py
+++ b/pystencils/nbackend/types/parsing.py
@@ -2,10 +2,7 @@ import numpy as np
 
 from .basic_types import (
     PsAbstractType,
-    PsCustomType,
-    PsScalarType,
     PsPointerType,
-    PsIntegerType,
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
diff --git a/pystencils/nbackend/types/quick.py b/pystencils/nbackend/types/quick.py
index 8c6b4d006..b1da0c5e2 100644
--- a/pystencils/nbackend/types/quick.py
+++ b/pystencils/nbackend/types/quick.py
@@ -31,8 +31,8 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
             - `int` becomes a signed 64-bit integer
             - `float` becomes a double-precision IEEE-754 float
             - No others are supported at the moment
-        - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) are converted to pystencils
-          scalar data types
+        - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html)
+          are converted to pystencils scalar data types
         - Instances of `PsAbstractType` will be returned as they are
     """
 
diff --git a/pystencils_tests/nbackend/test_constant_folding.py b/pystencils_tests/nbackend/test_constant_folding.py
new file mode 100644
index 000000000..12149c3ce
--- /dev/null
+++ b/pystencils_tests/nbackend/test_constant_folding.py
@@ -0,0 +1,45 @@
+import pytest
+
+import pymbolic.primitives as pb
+from pymbolic.mapper.constant_folder import ConstantFoldingMapper
+
+from pystencils.nbackend.types.quick import *
+from pystencils.nbackend.typed_expressions import PsTypedConstant
+
+
+@pytest.mark.parametrize("width", (8, 16, 32, 64))
+def test_constant_folding_int(width):
+    folder = ConstantFoldingMapper()
+
+    expr = pb.Sum(
+        (
+            PsTypedConstant(13, UInt(width)),
+            PsTypedConstant(5, UInt(width)),
+            PsTypedConstant(3, UInt(width)),
+        )
+    )
+
+    assert folder(expr) == PsTypedConstant(21, UInt(width))
+
+    expr = pb.Product(
+        (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width)))
+    ) - PsTypedConstant(12, SInt(width))
+
+    assert folder(expr) == PsTypedConstant(-53, SInt(width))
+
+
+@pytest.mark.parametrize("width", (32, 64))
+def test_constant_folding_float(width):
+    folder = ConstantFoldingMapper()
+
+    expr = pb.Quotient(
+        PsTypedConstant(14.0, Fp(width)),
+        pb.Sum(
+            (
+                PsTypedConstant(2.5, Fp(width)),
+                PsTypedConstant(4.5, Fp(width)),
+            )
+        ),
+    )
+
+    assert folder(expr) == PsTypedConstant(7.0, Fp(width))
diff --git a/pystencils_tests/nbackend/types/test_constants.py b/pystencils_tests/nbackend/types/test_constants.py
new file mode 100644
index 000000000..f166a1570
--- /dev/null
+++ b/pystencils_tests/nbackend/types/test_constants.py
@@ -0,0 +1,61 @@
+import pytest
+
+from pystencils.nbackend.types.quick import *
+from pystencils.nbackend.types import PsTypeError
+from pystencils.nbackend.typed_expressions import PsTypedConstant
+
+
+@pytest.mark.parametrize("width", (8, 16, 32, 64))
+def test_integer_constants(width):
+    dtype = SInt(width)
+    a = PsTypedConstant(42, dtype)
+    b = PsTypedConstant(2, dtype)
+
+    assert a + b == PsTypedConstant(44, dtype)
+    assert a - b == PsTypedConstant(40, dtype)
+    assert a * b == PsTypedConstant(84, dtype)
+
+    assert a - b != PsTypedConstant(-12, dtype)
+
+    #   Typed constants only compare to themselves
+    assert a + b != 44
+
+
+@pytest.mark.parametrize("width", (32, 64))
+def test_float_constants(width):
+    a = PsTypedConstant(32.0, Fp(width))
+    b = PsTypedConstant(0.5, Fp(width))
+    c = PsTypedConstant(2.0, Fp(width))
+
+    assert a + b == PsTypedConstant(32.5, Fp(width))
+    assert a * b == PsTypedConstant(16.0, Fp(width))
+    assert a - b == PsTypedConstant(31.5, Fp(width))
+    assert a / c == PsTypedConstant(16.0, Fp(width))
+
+        
+def test_illegal_ops():
+    #   Cannot interpret negative numbers as unsigned types
+    with pytest.raises(PsTypeError):
+        _ = PsTypedConstant(-3, UInt(32))
+
+    #   Mixed ops are illegal
+    with pytest.raises(PsTypeError):
+        _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32))
+
+    with pytest.raises(PsTypeError):
+        _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32))
+
+    with pytest.raises(PsTypeError):
+        _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32))
+
+    with pytest.raises(PsTypeError):
+        _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32))
+
+
+@pytest.mark.parametrize("width", (8, 16, 32, 64))
+def test_integer_division(width):
+    a = PsTypedConstant(-5, SInt(width))
+    b = PsTypedConstant(2, SInt(width))
+
+    assert a / b == PsTypedConstant(-2, SInt(width))
+    assert a % b == PsTypedConstant(-1, SInt(width))
-- 
GitLab