From fabbb472a7a99bab9b27684cddf975bc0f3d8dd9 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 5 Apr 2024 17:18:36 +0200
Subject: [PATCH] remove canonical_args. Refactor PsTypeMeta. Update
 documentation.

---
 docs/source/api/types.rst        | 13 +----
 src/pystencils/types/__init__.py |  7 ---
 src/pystencils/types/meta.py     | 97 ++++++++++++++++----------------
 src/pystencils/types/types.py    | 80 --------------------------
 4 files changed, 52 insertions(+), 145 deletions(-)

diff --git a/docs/source/api/types.rst b/docs/source/api/types.rst
index 624d956bc..5a740c058 100644
--- a/docs/source/api/types.rst
+++ b/docs/source/api/types.rst
@@ -4,6 +4,7 @@ Type System
 
 .. automodule:: pystencils.types
 
+
 Basic Functions
 -------------------------------------
 
@@ -13,7 +14,6 @@ Basic Functions
 .. autofunction:: pystencils.types.deconstify
 
 
-
 Data Type Class Hierarchy
 -------------------------
 
@@ -35,14 +35,7 @@ Data Type Abbreviations
     :members:
 
 
-Metaclass, Base Class and Uniquing Mechanisms
----------------------------------------------
+Implementation Details
+----------------------
 
 .. automodule:: pystencils.types.meta
-
-.. autoclass:: pystencils.types.meta.PsTypeMeta
-    :members:
-
-.. autofunction:: pystencils.types.PsType.__args__
-
-.. autofunction:: pystencils.types.PsType.__canonical_args__
diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py
index 5f4335839..e9b67096b 100644
--- a/src/pystencils/types/__init__.py
+++ b/src/pystencils/types/__init__.py
@@ -3,13 +3,6 @@ The `pystencils.types` module contains the set of classes used by pystencils
 to model data types. Data types are used extensively within the code generator,
 but can safely be ignored by most users unless you wish to force certain types on
 symbols, generate mixed-precision kernels, et cetera.
-
-The various classes that constitute the pystencils type system are implemented in
-`pystencils.types.types`; most have abbreviated names defined in `pystencils.types.quick`.
-
-For more information about the type system's internal workings, and developer's guidance on
-how to extend it, refer to `pystencils.types.meta`.
-
 """
 
 from .meta import PsType, constify, deconstify
diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py
index 02c2f5a3b..1d605edf8 100644
--- a/src/pystencils/types/meta.py
+++ b/src/pystencils/types/meta.py
@@ -1,11 +1,12 @@
 """
-Although mostly invisible to the user, types are ubiquitous throughout pystencils.
-They are created and converted in many places, especially in the code generation backend.
-To handle and compare types more efficiently, the pystencils type system implements
-a uniquing mechanism to ensure that at any point there exists only one instance of each type.
+
+Caching of Instances
+^^^^^^^^^^^^^^^^^^^^
+
+To handle and compare types more efficiently, the pystencils type system customizes class
+instantiation to cache and reuse existing instances of types.
 This means, for example, if a 32-bit const unsigned integer type gets created in two places
-at two different times in the program, the two types don't just behave identically, but
-in fact refer to the same object:
+in the program, both instantiations yield one and the same object:
 
 >>> from pystencils.types import PsUnsignedIntegerType
 >>> t1 = PsUnsignedIntegerType(32, const=True)
@@ -13,28 +14,24 @@ in fact refer to the same object:
 >>> t1 is t2
 True
 
-Both calls to `PsUnsignedIntegerType` return the same object. This is ensured by the
-`PsTypeMeta` metaclass.
-This metaclass holds an internal registry of all type objects ever created,
-and alters the class instantiation mechanism such that whenever a type is instantiated
-a second time with the same arguments, the pre-existing instance is found and returned instead.
+This mechanism is implemented by the metaclass `PsTypeMeta`. It is not perfect, however;
+some parts of Python that bypass the regular object creation sequence, such as `pickle` and
+`copy.copy`, may create additional instances of types.
 
-For this to work, all instantiable subclasses of `PsType` must implement the following protocol:
+.. autoclass:: pystencils.types.meta.PsTypeMeta
+    :members:
 
-- The ``const`` parameter must be the last keyword parameter of ``__init__``.
-- The ``__canonical_args__`` classmethod must have the same signature as ``__init__``, except it does
-  not take the ``const`` parameter. It must return a tuple containing all the positional and keyword
-  arguments in their canonical order. This method is used by `PsTypeMeta` to identify instances,
-  and to catch the various different possibilities Python offers for passing function arguments.
-- The ``__args__`` method, when called on an instance of the type, must return a tuple containing the constructor
-  arguments required to create that exact instance. This is used for comparing type objects
-  as well as const-conversion.
+Extending the Type System
+^^^^^^^^^^^^^^^^^^^^^^^^^
 
-As a rule, ``MyType.__canonical_args__(< arguments >)`` and ``MyType(< arguments >).__args__()`` must always return
-the same tuple.
+When extending the type system's class hierarchy, new classes need to implement at least the internal
+method `__args__`. This method, when called on a type object, must return a hashable tuple of arguments
+-- not including the const-qualifier --
+that can be used to recreate that exact type. It is used internally to compute hashes of types and to
+compare them for equality, as well as for const-conversion.
+
+.. autofunction:: pystencils.types.PsType.__args__
 
-Developers intending to extend the type class hierarchy are advised to study the implementations
-of this protocol in the existing classes.
 """
 
 from __future__ import annotations
@@ -47,24 +44,34 @@ import numpy as np
 class PsTypeMeta(ABCMeta):
     """Metaclass for the `PsType` hierarchy.
 
-    `PsTypeMeta` holds an internal cache of all instances of `PsType` and overrides object creation
-    such that whenever a type gets instantiated more than once, instead of creating a new object,
-    the existing object is returned.
+    `PsTypeMeta` holds an internal cache of all created instances of `PsType` and overrides object creation
+    such that whenever a type gets instantiated more than once with equivalent arguments, however they
+    are passed, the existing object is returned instead of a new one.
     """
 
     _instances: dict[Any, PsType] = dict()
 
-    def __call__(
-        cls: PsTypeMeta, *args: Any, const: bool = False, **kwargs: Any
-    ) -> Any:
+    def __call__(cls: PsTypeMeta, *args: Any, **kwargs: Any) -> Any:
         assert issubclass(cls, PsType)
-        canonical_args = cls.__canonical_args__(*args, **kwargs)
-        key = (cls, canonical_args, const)
+        kwarg_tuples = tuple(sorted(kwargs.items(), key=lambda t: t[0]))
+
+        try:
+            key = (cls, args, kwarg_tuples)
 
-        if key in cls._instances:
-            obj = cls._instances[key]
+            if key in cls._instances:
+                return cls._instances[key]
+        except TypeError:
+            key = None
+
+        obj = super().__call__(*args, **kwargs)
+        canonical_key = (cls, obj.__args__(), (("const", obj.const),))
+
+        if canonical_key in cls._instances:
+            obj = cls._instances[canonical_key]
         else:
-            obj = super().__call__(*args, const=const, **kwargs)
+            cls._instances[canonical_key] = obj
+
+        if key is not None:
             cls._instances[key] = obj
 
         return obj
@@ -78,37 +85,31 @@ class PsType(metaclass=PsTypeMeta):
     """
 
     #   -------------------------------------------------------------------------------------------
-    #   Internals: Object creation, pickling and unpickling
+    #   Arguments, Equality and Hashing
     #   -------------------------------------------------------------------------------------------
 
     @abstractmethod
     def __args__(self) -> tuple[Any, ...]:
         """Return the arguments used to create this instance, in canonical order, excluding the const-qualifier.
 
-        The tuple returned by this method is used to identify, check equality, and const-convert types.
-        For each instantiable subclass ``MyType`` of ``PsType``, the following must hold::
+        The tuple returned by this method must be hashable. For each instantiable subclass
+        ``MyType`` of ``PsType``, the following must hold::
 
             t = MyType(< arguments >)
-            assert MyType(*t.__args__()) == t
+            assert MyType(*t.__args__(), const=t.const) == t
 
         """
 
-    @classmethod
-    @abstractmethod
-    def __canonical_args__(cls, *args, **kwargs) -> tuple[Any, ...]:
-        """Return a tuple containing the positional and keyword arguments of ``__init__``
-        in their canonical order."""
-
     def __eq__(self, other: object) -> bool:
         if self is other:
             return True
-        
+
         if type(self) is not type(other):
             return False
-        
+
         other = cast(PsType, other)
         return self.const == other.const and self.__args__() == other.__args__()
-    
+
     def __hash__(self) -> int:
         return hash((type(self), self.const, self.__args__()))
 
diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index 617feeb4e..8e51f9397 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -22,15 +22,6 @@ class PsCustomType(PsType):
         super().__init__(const)
         self._name = name
 
-    @classmethod
-    def __canonical_args__(cls, name: str):
-        """
-        >>> t = PsCustomType(*PsCustomType.__canonical_args__(name="x"))
-        >>> t is PsCustomType("x")
-        True
-        """
-        return (name,)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsCustomType("std::vector< int >")
@@ -85,15 +76,6 @@ class PsPointerType(PsDereferencableType):
         super().__init__(base_type, const)
         self._restrict = restrict
 
-    @classmethod
-    def __canonical_args__(cls, base_type: PsType, restrict: bool = True):
-        """
-        >>> t = PsPointerType(*PsPointerType.__canonical_args__(restrict=False, base_type=PsBoolType()))
-        >>> t is PsPointerType(PsBoolType(), False)
-        True
-        """
-        return (base_type, restrict)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsPointerType(PsBoolType())
@@ -124,15 +106,6 @@ class PsArrayType(PsDereferencableType):
         self._length = length
         super().__init__(base_type, const)
 
-    @classmethod
-    def __canonical_args__(cls, base_type: PsType, length: int | None = None):
-        """
-        >>> t = PsArrayType(*PsArrayType.__canonical_args__(length=32, base_type=PsBoolType()))
-        >>> t is PsArrayType(PsBoolType(), 32)
-        True
-        """
-        return (base_type, length)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsArrayType(PsBoolType(), 13)
@@ -189,19 +162,6 @@ class PsStructType(PsType):
                 raise ValueError(f"Duplicate struct member name: {member.name}")
             names.add(member.name)
 
-    @classmethod
-    def __canonical_args__(
-        cls,
-        members: Sequence[PsStructType.Member | tuple[str, PsType]],
-        name: str | None = None,
-    ):
-        """
-        >>> t = PsStructType(*PsStructType.__canonical_args__(name="x", members=[("elem", PsBoolType())]))
-        >>> t is PsStructType([PsStructType.Member("elem", PsBoolType())], "x")
-        True
-        """
-        return (cls._canonical_members(members), name)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname")
@@ -354,15 +314,6 @@ class PsVectorType(PsNumericType):
         self._vector_entries = vector_entries
         self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type)
 
-    @classmethod
-    def __canonical_args__(cls, scalar_type: PsScalarType, vector_entries: int):
-        """
-        >>> t = PsVectorType(*PsVectorType.__canonical_args__(vector_entries=8, scalar_type=PsBoolType()))
-        >>> t is PsVectorType(PsBoolType(), 8)
-        True
-        """
-        return (scalar_type, vector_entries)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsVectorType(PsBoolType(), 8)
@@ -443,10 +394,6 @@ class PsBoolType(PsScalarType):
     def __init__(self, const: bool = False):
         super().__init__(const)
 
-    @classmethod
-    def __canonical_args__(cls, *args, **kwargs):
-        return ()
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsBoolType()
@@ -576,15 +523,6 @@ class PsSignedIntegerType(PsIntegerType):
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, True, const)
 
-    @classmethod
-    def __canonical_args__(cls, width: int):
-        """
-        >>> t = PsSignedIntegerType(*PsSignedIntegerType.__canonical_args__(width=8))
-        >>> t is PsSignedIntegerType(8)
-        True
-        """
-        return (width,)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsSignedIntegerType(32)
@@ -610,15 +548,6 @@ class PsUnsignedIntegerType(PsIntegerType):
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, False, const)
 
-    @classmethod
-    def __canonical_args__(cls, width: int):
-        """
-        >>> t = PsUnsignedIntegerType(*PsUnsignedIntegerType.__canonical_args__(width=8))
-        >>> t is PsUnsignedIntegerType(8)
-        True
-        """
-        return (width,)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsUnsignedIntegerType(32)
@@ -651,15 +580,6 @@ class PsIeeeFloatType(PsScalarType):
         super().__init__(const)
         self._width = width
 
-    @classmethod
-    def __canonical_args__(cls, width: int):
-        """
-        >>> t = PsIeeeFloatType(*PsIeeeFloatType.__canonical_args__(width=16))
-        >>> t is PsIeeeFloatType(16)
-        True
-        """
-        return (width,)
-
     def __args__(self) -> tuple[Any, ...]:
         """
         >>> t = PsIeeeFloatType(32)
-- 
GitLab