diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py index 915943cf8d730c9925c72e5f4896dd9591321d7a..198abf0ec2638accce560b201ec9923a899957bb 100644 --- a/src/pystencils/types/meta.py +++ b/src/pystencils/types/meta.py @@ -24,10 +24,14 @@ For this to work, all instantiable subclasses of `PsType` must implement the fol - The ``const`` parameter must be the last keyword parameter of ``__init__``. - The ``__canonical_args__`` classmethod must have the same signature as ``__init__``, except it does not take the ``const`` parameter. It must return a tuple containing all the positional and keyword - arguments in their canonical order. This method is used by `PsTypeMeta` to identify instances of the type, + arguments in their canonical order. This method is used by `PsTypeMeta` to identify instances, and to catch the various different possibilities Python offers for passing function arguments. - The ``__args__`` method, when called on an instance of the type, must return a tuple containing the constructor - arguments required to create that exact instance. + arguments required to create that exact instance. This is used for pickling and unpickling of type objects, + as well as const-conversion. + +As a rule, ``MyType.__canonical_args__(< arguments >)`` and ``MyType(< arguments >).__args__()`` must always return +the same tuple. Developers intending to extend the type class hierarchy are advised to study the implementations of this protocol in the existing classes. @@ -73,6 +77,10 @@ class PsType(metaclass=PsTypeMeta): const: Const-qualification of this type """ + # ------------------------------------------------------------------------------------------- + # Internals: Object creation, pickling and unpickling + # ------------------------------------------------------------------------------------------- + def __new__(cls, *args, _pickle=False, **kwargs): if _pickle: # force unpickler to use metaclass uniquing mechanism @@ -85,6 +93,34 @@ class PsType(metaclass=PsTypeMeta): kwargs = {"const": self._const, "_pickle": True} return args, kwargs + def __getstate__(self): + # To make sure pickle does not unnecessarily override the instance dictionary + return None + + @abstractmethod + def __args__(self) -> tuple[Any, ...]: + """Return the arguments used to create this instance, in canonical order, excluding the const-qualifier. + + The tuple returned by this method is used to serialize, deserialize, and const-convert types. + For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:: + + t = MyType(< arguments >) + assert MyType(*t.__args__()) == t + + """ + + @classmethod + @abstractmethod + def __canonical_args__(cls, *args, **kwargs) -> tuple[Any, ...]: + """Return a tuple containing the positional and keyword arguments of ``__init__`` + in their canonical order.""" + + # __eq__ and __hash__ unimplemented because due to uniquing, types are equal iff their instances are equal + + # ------------------------------------------------------------------------------------------- + # Constructor and properties + # ------------------------------------------------------------------------------------------- + def __init__(self, const: bool = False): self._const = const @@ -117,29 +153,9 @@ class PsType(metaclass=PsTypeMeta): return None # ------------------------------------------------------------------------------------------- - # Internal operations + # String Conversion # ------------------------------------------------------------------------------------------- - @abstractmethod - def __args__(self) -> tuple[Any, ...]: - """Arguments to this type, excluding the const-qualifier. - - The tuple returned by this method is used to serialize, deserialize, and check equality of types. - For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:: - - t = MyType(< arguments >) - assert MyType(*t.__args__()) == t - - """ - pass - - @classmethod - @abstractmethod - def __canonical_args__(cls, *args, **kwargs): - """Return a tuple containing the positional and keyword arguments of ``__init__`` - in their canonical order.""" - pass - def _const_string(self) -> str: return "const " if self._const else "" @@ -147,26 +163,9 @@ class PsType(metaclass=PsTypeMeta): def c_string(self) -> str: pass - # ------------------------------------------------------------------------------------------- - # Dunder Methods - # ------------------------------------------------------------------------------------------- - - def __eq__(self, other: object) -> bool: - if self is other: - return True - - if type(self) is not type(other): - return False - - other = cast(PsType, other) - return self._const == other._const and self.__args__() == other.__args__() - def __str__(self) -> str: return self.c_string() - def __hash__(self) -> int: - return hash((type(self), self.__args__())) - T = TypeVar("T", bound=PsType) diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 5499d359158f2d33669edf977abf628c800919ed..617feeb4e2787e2b5561a75f076cdb368f597ade 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -24,6 +24,11 @@ class PsCustomType(PsType): @classmethod def __canonical_args__(cls, name: str): + """ + >>> t = PsCustomType(*PsCustomType.__canonical_args__(name="x")) + >>> t is PsCustomType("x") + True + """ return (name,) def __args__(self) -> tuple[Any, ...]: @@ -82,6 +87,11 @@ class PsPointerType(PsDereferencableType): @classmethod def __canonical_args__(cls, base_type: PsType, restrict: bool = True): + """ + >>> t = PsPointerType(*PsPointerType.__canonical_args__(restrict=False, base_type=PsBoolType())) + >>> t is PsPointerType(PsBoolType(), False) + True + """ return (base_type, restrict) def __args__(self) -> tuple[Any, ...]: @@ -116,6 +126,11 @@ class PsArrayType(PsDereferencableType): @classmethod def __canonical_args__(cls, base_type: PsType, length: int | None = None): + """ + >>> t = PsArrayType(*PsArrayType.__canonical_args__(length=32, base_type=PsBoolType())) + >>> t is PsArrayType(PsBoolType(), 32) + True + """ return (base_type, length) def __args__(self) -> tuple[Any, ...]: @@ -180,6 +195,11 @@ class PsStructType(PsType): members: Sequence[PsStructType.Member | tuple[str, PsType]], name: str | None = None, ): + """ + >>> t = PsStructType(*PsStructType.__canonical_args__(name="x", members=[("elem", PsBoolType())])) + >>> t is PsStructType([PsStructType.Member("elem", PsBoolType())], "x") + True + """ return (cls._canonical_members(members), name) def __args__(self) -> tuple[Any, ...]: @@ -336,6 +356,11 @@ class PsVectorType(PsNumericType): @classmethod def __canonical_args__(cls, scalar_type: PsScalarType, vector_entries: int): + """ + >>> t = PsVectorType(*PsVectorType.__canonical_args__(vector_entries=8, scalar_type=PsBoolType())) + >>> t is PsVectorType(PsBoolType(), 8) + True + """ return (scalar_type, vector_entries) def __args__(self) -> tuple[Any, ...]: @@ -553,6 +578,11 @@ class PsSignedIntegerType(PsIntegerType): @classmethod def __canonical_args__(cls, width: int): + """ + >>> t = PsSignedIntegerType(*PsSignedIntegerType.__canonical_args__(width=8)) + >>> t is PsSignedIntegerType(8) + True + """ return (width,) def __args__(self) -> tuple[Any, ...]: @@ -582,6 +612,11 @@ class PsUnsignedIntegerType(PsIntegerType): @classmethod def __canonical_args__(cls, width: int): + """ + >>> t = PsUnsignedIntegerType(*PsUnsignedIntegerType.__canonical_args__(width=8)) + >>> t is PsUnsignedIntegerType(8) + True + """ return (width,) def __args__(self) -> tuple[Any, ...]: @@ -610,7 +645,7 @@ class PsIeeeFloatType(PsScalarType): def __init__(self, width: int, const: bool = False): if width not in self.SUPPORTED_WIDTHS: raise ValueError( - f"Invalid integer width {width}; must be one of {self.SUPPORTED_WIDTHS}." + f"Invalid floating-point width {width}; must be one of {self.SUPPORTED_WIDTHS}." ) super().__init__(const) @@ -618,6 +653,11 @@ class PsIeeeFloatType(PsScalarType): @classmethod def __canonical_args__(cls, width: int): + """ + >>> t = PsIeeeFloatType(*PsIeeeFloatType.__canonical_args__(width=16)) + >>> t is PsIeeeFloatType(16) + True + """ return (width,) def __args__(self) -> tuple[Any, ...]: