From ac3b5e34705bbeb46cfa20b9035035f656469c7a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 3 Apr 2024 13:11:36 +0200
Subject: [PATCH] Introduce Type Uniquing Mechanism

 - Introduce metaclass PsTypeMeta
 - Refactor __args__ protocol to exclude const
 - Move PsType and PsTypeMeta to types/meta.py
 - Rename basic_types.py to types.py
 - Adapt test cases to check for identity
---
 src/pystencils/backend/ast/expressions.py     |   4 +-
 src/pystencils/backend/constants.py           |  14 +-
 .../backend/kernelcreation/typification.py    |  10 +-
 .../backend/platforms/generic_gpu.py          |   8 +-
 src/pystencils/sympyextensions/pointers.py    |   2 +-
 src/pystencils/sympyextensions/typed_sympy.py |   2 +-
 src/pystencils/types/__init__.py              |   7 +-
 src/pystencils/types/meta.py                  | 140 +++++++++++++++
 src/pystencils/types/parsing.py               |  10 +-
 src/pystencils/types/quick.py                 |   2 +-
 .../types/{basic_types.py => types.py}        | 159 ++++--------------
 .../kernelcreation/test_typification.py       |   3 +
 tests/nbackend/types/test_types.py            |  37 ++--
 13 files changed, 227 insertions(+), 171 deletions(-)
 create mode 100644 src/pystencils/types/meta.py
 rename src/pystencils/types/{basic_types.py => types.py} (80%)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 8a66457a9..7c743a399 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -21,7 +21,7 @@ from .astnode import PsAstNode, PsLeafMixIn
 
 class PsExpression(PsAstNode, ABC):
     """Base class for all expressions.
-    
+
     **Types:** Each expression should be annotated with its type.
     Upon construction, the `dtype` property of most expression nodes is unset;
     only constant expressions, symbol expressions, and array accesses immediately inherit their type from
@@ -271,7 +271,7 @@ class PsVectorArrayAccess(PsArrayAccess):
     @property
     def alignment(self) -> int:
         return self._alignment
-    
+
     def get_vector_type(self) -> PsVectorType:
         return cast(PsVectorType, self._dtype)
 
diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py
index 125c1149b..b867d89d3 100644
--- a/src/pystencils/backend/constants.py
+++ b/src/pystencils/backend/constants.py
@@ -7,10 +7,10 @@ from .exceptions import PsInternalCompilerError
 
 class PsConstant:
     """Type-safe representation of typed numerical constants.
-    
+
     This class models constants in the backend representation of kernels.
     A constant may be *untyped*, in which case its ``value`` may be any Python object.
-    
+
     If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used
     to check the validity of its ``value`` and to convert it into the type's internal representation.
 
@@ -36,19 +36,19 @@ class PsConstant:
 
     def interpret_as(self, dtype: PsNumericType) -> PsConstant:
         """Interprets this *untyped* constant with the given data type.
-        
+
         If this constant is already typed, raises an error.
         """
         if self._dtype is not None:
             raise PsInternalCompilerError(
                 f"Cannot interpret already typed constant {self} with type {dtype}"
             )
-        
+
         return PsConstant(self._value, dtype)
-    
+
     def reinterpret_as(self, dtype: PsNumericType) -> PsConstant:
         """Reinterprets this constant with the given data type.
-        
+
         Other than `interpret_as`, this method also works on typed constants.
         """
         return PsConstant(self._value, dtype)
@@ -60,7 +60,7 @@ class PsConstant:
     @property
     def dtype(self) -> PsNumericType | None:
         """This constant's data type, or ``None`` if it is untyped.
-        
+
         The data type of a constant always has ``const == True``.
         """
         return self._dtype
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 190cd9e23..bfecec5be 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -56,7 +56,7 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
 
 class TypeContext:
     """Typing context, with support for type inference and checking.
-    
+
     Instances of this class are used to propagate and check data types across expression subtrees
     of the AST. Each type context has:
 
@@ -185,7 +185,7 @@ class TypeContext:
 
     def _compatible(self, dtype: PsType):
         """Checks whether the given data type is compatible with the context's target type.
-        
+
         If the target type is ``const``, they must be equal up to const qualification;
         if the target type is not ``const``, `dtype` must match it exactly.
         """
@@ -248,7 +248,7 @@ class Typifier:
 
     Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
     not necessarily their const-qualification.
-    A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, 
+    A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
     and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
     but not vice versa.
     """
@@ -321,7 +321,7 @@ class Typifier:
 
     def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
         """Recursive processing of expression nodes.
-        
+
         This method opens, expands, and closes typing contexts according to the respective expression's
         typing rules. It may add or check restrictions only when opening or closing a type context.
 
@@ -394,7 +394,7 @@ class Typifier:
                         f"Unable to determine type of argument to AddressOf: {arg}"
                     )
 
-                ptr_type = PsPointerType(arg_tc.target_type, True)
+                ptr_type = PsPointerType(arg_tc.target_type, const=True)
                 tc.apply_dtype(ptr_type, expr)
 
             case PsLookup(aggr, member_name):
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 64c0cd3e9..8b1d8f783 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -1,5 +1,5 @@
 from pystencils.backend.functions import CFunction, PsMathFunction
-from pystencils.types.basic_types import PsType
+from pystencils.types.types import PsType
 from .platform import Platform
 
 from ..kernelcreation.iteration_space import (
@@ -56,8 +56,10 @@ class GenericGpu(Platform):
         ]
 
         return indices[:dim]
-    
-    def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction:
+
+    def select_function(
+        self, math_function: PsMathFunction, dtype: PsType
+    ) -> CFunction:
         raise NotImplementedError()
 
     #   Internals
diff --git a/src/pystencils/sympyextensions/pointers.py b/src/pystencils/sympyextensions/pointers.py
index 130338c99..a814f941e 100644
--- a/src/pystencils/sympyextensions/pointers.py
+++ b/src/pystencils/sympyextensions/pointers.py
@@ -26,6 +26,6 @@ class AddressOf(sp.Function):
     @property
     def dtype(self):
         if hasattr(self.args[0], 'dtype'):
-            return PsPointerType(self.args[0].dtype, const=True, restrict=True)
+            return PsPointerType(self.args[0].dtype, restrict=True, const=True)
         else:
             raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}')
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index 541f9aed7..71caf5a2f 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -172,7 +172,7 @@ class FieldPointerSymbol(TypedSymbol):
 
     def __new_stage2__(cls, field_name, field_dtype: PsType, const: bool):
         name = f"_data_{field_name}"
-        dtype = PsPointerType(field_dtype, const=const, restrict=True)
+        dtype = PsPointerType(field_dtype, restrict=True, const=const)
         obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
         obj.field_name = field_name
         return obj
diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py
index 2c7bac59b..05e14fa70 100644
--- a/src/pystencils/types/__init__.py
+++ b/src/pystencils/types/__init__.py
@@ -8,8 +8,9 @@ For more user-friendly and less verbose access to the type modelling system, ref
 the `pystencils.types.quick` submodule. 
 """
 
-from .basic_types import (
-    PsType,
+from .meta import PsType, constify, deconstify
+
+from .types import (
     PsCustomType,
     PsStructType,
     PsNumericType,
@@ -23,8 +24,6 @@ from .basic_types import (
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
-    constify,
-    deconstify,
 )
 
 from .quick import UserTypeSpec, create_type, create_numeric_type
diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py
new file mode 100644
index 000000000..4a115ef97
--- /dev/null
+++ b/src/pystencils/types/meta.py
@@ -0,0 +1,140 @@
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+from typing import TypeVar, Any, cast
+import numpy as np
+
+
+class PsTypeMeta(ABCMeta):
+
+    _instances: dict[Any, PsType] = dict()
+
+    def __call__(cls, *args: Any, const: bool = False, **kwargs: Any) -> Any:
+        obj = super(PsTypeMeta, cls).__call__(*args, const=const, **kwargs)
+        canonical_args = obj.__args__()
+        key = (cls, canonical_args, const)
+
+        if key in cls._instances:
+            obj = cls._instances[key]
+        else:
+            cls._instances[key] = obj
+
+        return obj
+
+
+class PsType(metaclass=PsTypeMeta):
+    """Base class for all pystencils types.
+
+    Args:
+        const: Const-qualification of this type
+
+    **Implementation details for subclasses:**
+    `PsType` and its metaclass ``PsTypeMeta`` together implement a uniquing mechanism to ensure that of each type,
+    only one instance ever exists in the public.
+    For this to work, subclasses have to adhere to several rules:
+
+     - All instances of `PsType` must be immutable.
+     - The `const` argument must be the last keyword argument to ``__init__`` and must be passed to the superclass
+       ``__init__``.
+     - The `__args__` method must return a tuple of positional arguments excluding the `const` property,
+       which, when passed to the class's constructor, create an identically-behaving instance.
+    """
+
+    def __init__(self, const: bool = False):
+        self._const = const
+
+        self._requalified: PsType | None = None
+
+    @property
+    def const(self) -> bool:
+        return self._const
+
+    #   -------------------------------------------------------------------------------------------
+    #   Optional Info
+    #   -------------------------------------------------------------------------------------------
+
+    @property
+    def required_headers(self) -> set[str]:
+        """The set of header files required when this type occurs in generated code."""
+        return set()
+
+    @property
+    def itemsize(self) -> int | None:
+        """If this type has a valid in-memory size, return that size."""
+        return None
+
+    @property
+    def numpy_dtype(self) -> np.dtype | None:
+        """A np.dtype object representing this data type.
+
+        Available both for backward compatibility and for interaction with the numpy-based runtime system.
+        """
+        return None
+
+    #   -------------------------------------------------------------------------------------------
+    #   Internal operations
+    #   -------------------------------------------------------------------------------------------
+
+    @abstractmethod
+    def __args__(self) -> tuple[Any, ...]:
+        """Arguments to this type, excluding the const-qualifier.
+
+        The tuple returned by this method is used to serialize, deserialize, and check equality of types.
+        For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
+
+        ```
+        t = MyType(< arguments >)
+        assert MyType(*t.__args__()) == t
+        ```
+        """
+        pass
+
+    def _const_string(self) -> str:
+        return "const " if self._const else ""
+
+    @abstractmethod
+    def c_string(self) -> str:
+        pass
+
+    #   -------------------------------------------------------------------------------------------
+    #   Dunder Methods
+    #   -------------------------------------------------------------------------------------------
+
+    def __eq__(self, other: object) -> bool:
+        if self is other:
+            return True
+
+        if type(self) is not type(other):
+            return False
+
+        other = cast(PsType, other)
+        return self._const == other._const and self.__args__() == other.__args__()
+
+    def __str__(self) -> str:
+        return self.c_string()
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.__args__()))
+
+
+T = TypeVar("T", bound=PsType)
+
+
+def constify(t: T) -> T:
+    """Adds the const qualifier to a given type."""
+    if not t.const:
+        if t._requalified is None:
+            t._requalified = type(t)(*t.__args__(), const=True)  # type: ignore
+        return cast(T, t._requalified)
+    else:
+        return t
+
+
+def deconstify(t: T) -> T:
+    """Removes the const qualifier from a given type."""
+    if t.const:
+        if t._requalified is None:
+            t._requalified = type(t)(*t.__args__(), const=False)  # type: ignore
+        return cast(T, t._requalified)
+    else:
+        return t
diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py
index e28b83ae7..40f989a09 100644
--- a/src/pystencils/types/parsing.py
+++ b/src/pystencils/types/parsing.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from .basic_types import (
+from .types import (
     PsType,
     PsPointerType,
     PsStructType,
@@ -76,13 +76,13 @@ def parse_type_string(s: str) -> PsType:
             base_type = parse_type_string(base)
             match suffix.split():
                 case []:
-                    return PsPointerType(base_type, const=False, restrict=False)
+                    return PsPointerType(base_type, restrict=False, const=False)
                 case ["const"]:
-                    return PsPointerType(base_type, const=True, restrict=False)
+                    return PsPointerType(base_type, restrict=False, const=True)
                 case ["restrict"]:
-                    return PsPointerType(base_type, const=False, restrict=True)
+                    return PsPointerType(base_type, restrict=True, const=False)
                 case ["const", "restrict"] | ["restrict", "const"]:
-                    return PsPointerType(base_type, const=True, restrict=True)
+                    return PsPointerType(base_type, restrict=True, const=True)
                 case _:
                     raise ValueError(f"Could not parse token '{s}' as C type.")
 
diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py
index c1a3aadc5..1c44ba398 100644
--- a/src/pystencils/types/quick.py
+++ b/src/pystencils/types/quick.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import numpy as np
 
-from .basic_types import (
+from .types import (
     PsType,
     PsCustomType,
     PsNumericType,
diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/types.py
similarity index 80%
rename from src/pystencils/types/basic_types.py
rename to src/pystencils/types/types.py
index ebfe9d610..b61c421d0 100644
--- a/src/pystencils/types/basic_types.py
+++ b/src/pystencils/types/types.py
@@ -1,94 +1,12 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import final, TypeVar, Any, Sequence, cast
+from typing import final, Any, Sequence
 from dataclasses import dataclass
-from copy import copy
 
 import numpy as np
 
 from .exception import PsTypeError
-
-
-class PsType(ABC):
-    """Base class for all pystencils types.
-
-    Args:
-        const: Const-qualification of this type
-    """
-
-    def __init__(self, const: bool = False):
-        self._const = const
-
-    @property
-    def const(self) -> bool:
-        return self._const
-
-    #   -------------------------------------------------------------------------------------------
-    #   Optional Info
-    #   -------------------------------------------------------------------------------------------
-
-    @property
-    def required_headers(self) -> set[str]:
-        """The set of header files required when this type occurs in generated code."""
-        return set()
-
-    @property
-    def itemsize(self) -> int | None:
-        """If this type has a valid in-memory size, return that size."""
-        return None
-
-    @property
-    def numpy_dtype(self) -> np.dtype | None:
-        """A np.dtype object representing this data type.
-
-        Available both for backward compatibility and for interaction with the numpy-based runtime system.
-        """
-        return None
-
-    #   -------------------------------------------------------------------------------------------
-    #   Internal operations
-    #   -------------------------------------------------------------------------------------------
-
-    @abstractmethod
-    def __args__(self) -> tuple[Any, ...]:
-        """Arguments to this type.
-        
-        The tuple returned by this method is used to serialize, deserialize, and check equality of types.
-        For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
-
-        ```
-        t = MyType(< arguments >)
-        assert MyType(*t.__args__()) == t
-        ```
-        """
-        pass
-
-    def _const_string(self) -> str:
-        return "const " if self._const else ""
-
-    @abstractmethod
-    def c_string(self) -> str:
-        pass
-
-    #   -------------------------------------------------------------------------------------------
-    #   Dunder Methods
-    #   -------------------------------------------------------------------------------------------
-
-    def __eq__(self, other: object) -> bool:
-        if self is other:
-            return True
-        
-        if type(self) is not type(other):
-            return False
-        
-        other = cast(PsType, other)
-        return self.__args__() == other.__args__()
-
-    def __str__(self) -> str:
-        return self.c_string()
-
-    def __hash__(self) -> int:
-        return hash((type(self), self.__args__()))
+from .meta import PsType, constify, deconstify
 
 
 class PsCustomType(PsType):
@@ -154,17 +72,17 @@ class PsPointerType(PsDereferencableType):
 
     __match_args__ = ("base_type",)
 
-    def __init__(self, base_type: PsType, const: bool = False, restrict: bool = True):
+    def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False):
         super().__init__(base_type, const)
         self._restrict = restrict
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsPointerType(PsBoolType(), const=True)
+        >>> t = PsPointerType(PsBoolType())
         >>> t == PsPointerType(*t.__args__())
         True
         """
-        return (self._base_type, self._const, self._restrict)
+        return (self._base_type, self._restrict)
 
     @property
     def restrict(self) -> bool:
@@ -190,11 +108,11 @@ class PsArrayType(PsDereferencableType):
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsArrayType(PsBoolType(), 13, const=True)
+        >>> t = PsArrayType(PsBoolType(), 13)
         >>> t == PsArrayType(*t.__args__())
         True
         """
-        return (self._base_type, self._length, self._const)
+        return (self._base_type, self._length)
 
     @property
     def length(self) -> int | None:
@@ -246,7 +164,7 @@ class PsStructType(PsType):
         >>> t == PsStructType(*t.__args__())
         True
         """
-        return (self._members, self._name, self._const)
+        return (self._members, self._name)
 
     @property
     def members(self) -> tuple[PsStructType.Member, ...]:
@@ -394,11 +312,11 @@ class PsVectorType(PsNumericType):
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsVectorType(PsBoolType(), 8, True)
+        >>> t = PsVectorType(PsBoolType(), 8)
         >>> t == PsVectorType(*t.__args__())
         True
         """
-        return (self._scalar_type, self._vector_entries, self._const)
+        return (self._scalar_type, self._vector_entries)
 
     @property
     def scalar_type(self) -> PsScalarType:
@@ -474,11 +392,11 @@ class PsBoolType(PsScalarType):
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsBoolType(True)
+        >>> t = PsBoolType()
         >>> t == PsBoolType(*t.__args__())
         True
         """
-        return (self._const,)
+        return ()
 
     @property
     def width(self) -> int:
@@ -494,7 +412,9 @@ class PsBoolType(PsScalarType):
 
     def create_literal(self, value: Any) -> str:
         if not isinstance(value, self.NUMPY_TYPE):
-            raise PsTypeError(f"Given value {value} is not of required type {self.NUMPY_TYPE}")
+            raise PsTypeError(
+                f"Given value {value} is not of required type {self.NUMPY_TYPE}"
+            )
 
         if value == np.True_:
             return "true"
@@ -513,7 +433,7 @@ class PsBoolType(PsScalarType):
 
     def c_string(self) -> str:
         return "bool"
-    
+
 
 class PsIntegerType(PsScalarType, ABC):
     """Signed and unsigned integer types.
@@ -561,18 +481,20 @@ class PsIntegerType(PsScalarType, ABC):
         unsigned_suffix = "" if self.signed else "u"
         #   TODO: cast literal to correct type?
         return str(value) + unsigned_suffix
-    
+
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
 
         if isinstance(value, (int, np.integer)):
             iinfo = np.iinfo(np_type)  # type: ignore
             if value < iinfo.min or value > iinfo.max:
-                raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
+                raise PsTypeError(
+                    f"Could not interpret {value} as {self}: Value is out of bounds."
+                )
             return np_type(value)
 
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
-    
+
     def c_string(self) -> str:
         prefix = "" if self._signed else "u"
         return f"{self._const_string()}{prefix}int{self._width}_t"
@@ -599,11 +521,11 @@ class PsSignedIntegerType(PsIntegerType):
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsSignedIntegerType(32, True)
+        >>> t = PsSignedIntegerType(32)
         >>> t == PsSignedIntegerType(*t.__args__())
         True
         """
-        return (self._width, self._const)
+        return (self._width,)
 
 
 @final
@@ -624,11 +546,11 @@ class PsUnsignedIntegerType(PsIntegerType):
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsUnsignedIntegerType(32, True)
+        >>> t = PsUnsignedIntegerType(32)
         >>> t == PsUnsignedIntegerType(*t.__args__())
         True
         """
-        return (self._width, self._const)
+        return (self._width,)
 
 
 @final
@@ -656,11 +578,11 @@ class PsIeeeFloatType(PsScalarType):
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsIeeeFloatType(32, True)
+        >>> t = PsIeeeFloatType(32)
         >>> t == PsIeeeFloatType(*t.__args__())
         True
         """
-        return (self._width, self._const)
+        return (self._width,)
 
     @property
     def width(self) -> int:
@@ -702,7 +624,9 @@ class PsIeeeFloatType(PsScalarType):
         if isinstance(value, (int, float, np.floating)):
             finfo = np.finfo(np_type)  # type: ignore
             if value < finfo.min or value > finfo.max:
-                raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
+                raise PsTypeError(
+                    f"Could not interpret {value} as {self}: Value is out of bounds."
+                )
             return np_type(value)
 
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
@@ -720,26 +644,3 @@ class PsIeeeFloatType(PsScalarType):
 
     def __repr__(self) -> str:
         return f"PsIeeeFloatType( width={self.width}, const={self.const} )"
-
-
-T = TypeVar("T", bound=PsType)
-
-
-def constify(t: T) -> T:
-    """Adds the const qualifier to a given type."""
-    if not t.const:
-        t_copy = copy(t)
-        t_copy._const = True
-        return t_copy
-    else:
-        return t
-
-
-def deconstify(t: T) -> T:
-    """Removes the const qualifier from a given type."""
-    if t.const:
-        t_copy = copy(t)
-        t_copy._const = False
-        return t_copy
-    else:
-        return t
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index d9cc5f9ce..abe22ccc1 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -289,3 +289,6 @@ def test_typify_constant_clones():
 
     assert expr_clone.operand1.dtype is None
     assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
+
+
+test_lhs_constness()
\ No newline at end of file
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 24a46ab90..74467080a 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -19,13 +19,13 @@ def test_widths(Type):
 
 
 def test_parsing_positive():
-    assert create_type("const uint32_t * restrict") == Ptr(
+    assert create_type("const uint32_t * restrict") is Ptr(
         UInt(32, const=True), restrict=True
     )
-    assert create_type("float * * const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
-    assert create_type("float * * restrict const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
-    assert create_type("uint16 * const") == Ptr(UInt(16), const=True, restrict=False)
-    assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True, restrict=False)
+    assert create_type("float * * const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
+    assert create_type("float * * restrict const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
+    assert create_type("uint16 * const") is Ptr(UInt(16), const=True, restrict=False)
+    assert create_type("uint64 const * const") is Ptr(UInt(64, const=True), const=True, restrict=False)
 
 
 def test_parsing_negative():
@@ -45,14 +45,14 @@ def test_parsing_negative():
 def test_numpy():
     import numpy as np
 
-    assert create_type(np.single) == create_type(np.float32) == PsIeeeFloatType(32)
+    assert create_type(np.single) is create_type(np.float32) is PsIeeeFloatType(32)
     assert (
         create_type(float)
-        == create_type(np.double)
-        == create_type(np.float64)
-        == PsIeeeFloatType(64)
+        is create_type(np.double)
+        is create_type(np.float64)
+        is PsIeeeFloatType(64)
     )
-    assert create_type(int) == create_type(np.int64) == PsSignedIntegerType(64)
+    assert create_type(int) is create_type(np.int64) is PsSignedIntegerType(64)
 
 
 @pytest.mark.parametrize(
@@ -102,10 +102,21 @@ def test_numpy_translation(numpy_type):
 
 def test_constify():
     t = PsCustomType("std::shared_ptr< Custom >")
-    assert deconstify(t) == t
-    assert deconstify(constify(t)) == t
+    assert deconstify(t) is t
+    assert deconstify(constify(t)) is t
+    
     s = PsCustomType("Field", const=True)
-    assert constify(s) == s
+    assert constify(s) is s
+
+    i32 = create_type(np.int32)
+    i32_2 = PsSignedIntegerType(32)
+
+    assert i32 is i32_2
+    assert constify(i32) is constify(i32_2)
+
+    i32_const = PsSignedIntegerType(32, const=True)
+    assert i32_const is not i32
+    assert i32_const is constify(i32)
 
 
 def test_struct_types():
-- 
GitLab