From 8a4a65b929d72861f3f2183efa3b5e767acc1250 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 3 Apr 2024 18:52:49 +0200 Subject: [PATCH] Add pickle support and extend uniquing protocl for types. - Add pickle support to PsType - Add `__canonical_args__` protocol for more efficient uniquing - Write extensive documentation around PsTypeMeta - Refactor data types section in docs --- .gitlab-ci.yml | 12 +++ docs/source/api/types.rst | 34 ++++++-- src/pystencils/field.py | 2 +- src/pystencils/sympyextensions/typed_sympy.py | 3 +- src/pystencils/types/__init__.py | 10 ++- src/pystencils/types/meta.py | 87 +++++++++++++++---- src/pystencils/types/parsing.py | 48 ++++++++++ src/pystencils/types/quick.py | 56 +----------- src/pystencils/types/types.py | 60 +++++++++++-- .../kernelcreation/test_typification.py | 4 +- tests/nbackend/types/test_types.py | 41 ++++++++- 11 files changed, 257 insertions(+), 100 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 664f20570..dbe3ae0eb 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -330,6 +330,18 @@ nbackend-unit-tests: tags: - docker +doctest: + stage: "Unit Tests" + needs: [] + image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full + before_script: + - pip install -e .[tests] + script: + - pytest src/pystencils/backend + - pytest src/pystencils/types + tags: + - docker + # -------------------- Documentation --------------------------------------------------------------------- diff --git a/docs/source/api/types.rst b/docs/source/api/types.rst index 1794fd856..ce08697e2 100644 --- a/docs/source/api/types.rst +++ b/docs/source/api/types.rst @@ -1,21 +1,26 @@ -********** -Data Types -********** - -Type System Module (pystencils.types) -------------------------------------- +*********** +Type System +*********** .. automodule:: pystencils.types +Basic Functions +------------------------------------- + .. autofunction:: pystencils.types.create_type .. autofunction:: pystencils.types.create_numeric_type +.. autofunction:: pystencils.types.constify +.. autofunction:: pystencils.types.deconstify Data Type Class Hierarchy ------------------------- -.. automodule:: pystencils.types.basic_types +.. autoclass:: pystencils.types.PsType + :members: + +.. automodule:: pystencils.types.types :members: @@ -23,4 +28,17 @@ Data Type Abbreviations ----------------------- .. automodule:: pystencils.types.quick - :members: \ No newline at end of file + :members: + + +Metaclass, Base Class and Uniquing Mechanisms +--------------------------------------------- + +.. automodule:: pystencils.types.meta + +.. autoclass:: pystencils.types.meta.PsTypeMeta + :members: + +.. autofunction:: pystencils.types.PsType.__args__ + +.. autofunction:: pystencils.types.PsType.__canonical_args__ diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 2a82b8a43..ac0aa3ae0 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -17,7 +17,7 @@ from pystencils.stencil import direction_string_to_offset, inverse_direction, of from pystencils.types import PsType, PsStructType, create_type from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) from pystencils.sympyextensions.math import is_integer_sequence -from pystencils.types.quick import UserTypeSpec +from pystencils.types import UserTypeSpec __all__ = ['Field', 'fields', 'FieldType', 'Field'] diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 71caf5a2f..e022db511 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,7 +1,6 @@ import sympy as sp -from ..types import PsType, PsNumericType, PsPointerType, PsBoolType -from ..types.quick import create_type +from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type def assumptions_from_dtype(dtype: PsType): diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 05e14fa70..5f4335839 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -4,8 +4,12 @@ to model data types. Data types are used extensively within the code generator, but can safely be ignored by most users unless you wish to force certain types on symbols, generate mixed-precision kernels, et cetera. -For more user-friendly and less verbose access to the type modelling system, refer to -the `pystencils.types.quick` submodule. +The various classes that constitute the pystencils type system are implemented in +`pystencils.types.types`; most have abbreviated names defined in `pystencils.types.quick`. + +For more information about the type system's internal workings, and developer's guidance on +how to extend it, refer to `pystencils.types.meta`. + """ from .meta import PsType, constify, deconstify @@ -26,7 +30,7 @@ from .types import ( PsIeeeFloatType, ) -from .quick import UserTypeSpec, create_type, create_numeric_type +from .parsing import UserTypeSpec, create_type, create_numeric_type from .exception import PsTypeError diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py index 4a115ef97..915943cf8 100644 --- a/src/pystencils/types/meta.py +++ b/src/pystencils/types/meta.py @@ -1,3 +1,38 @@ +""" +Although mostly invisible to the user, types are ubiquitous throughout pystencils. +They are created and converted in many places, especially in the code generation backend. +To handle and compare types more efficiently, the pystencils type system implements +a uniquing mechanism to ensure that at any point there exists only one instance of each type. +This means, for example, if a 32-bit const unsigned integer type gets created in two places +at two different times in the program, the two types don't just behave identically, but +in fact refer to the same object: + +>>> from pystencils.types import PsUnsignedIntegerType +>>> t1 = PsUnsignedIntegerType(32, const=True) +>>> t2 = PsUnsignedIntegerType(32, const=True) +>>> t1 is t2 +True + +Both calls to `PsUnsignedIntegerType` return the same object. This is ensured by the +`PsTypeMeta` metaclass. +This metaclass holds an internal registry of all type objects ever created, +and alters the class instantiation mechanism such that whenever a type is instantiated +a second time with the same arguments, the pre-existing instance is found and returned instead. + +For this to work, all instantiable subclasses of `PsType` must implement the following protocol: + +- The ``const`` parameter must be the last keyword parameter of ``__init__``. +- The ``__canonical_args__`` classmethod must have the same signature as ``__init__``, except it does + not take the ``const`` parameter. It must return a tuple containing all the positional and keyword + arguments in their canonical order. This method is used by `PsTypeMeta` to identify instances of the type, + and to catch the various different possibilities Python offers for passing function arguments. +- The ``__args__`` method, when called on an instance of the type, must return a tuple containing the constructor + arguments required to create that exact instance. + +Developers intending to extend the type class hierarchy are advised to study the implementations +of this protocol in the existing classes. +""" + from __future__ import annotations from abc import ABCMeta, abstractmethod @@ -6,17 +41,26 @@ import numpy as np class PsTypeMeta(ABCMeta): + """Metaclass for the `PsType` hierarchy. + + `PsTypeMeta` holds an internal cache of all instances of `PsType` and overrides object creation + such that whenever a type gets instantiated more than once, instead of creating a new object, + the existing object is returned. + """ _instances: dict[Any, PsType] = dict() - def __call__(cls, *args: Any, const: bool = False, **kwargs: Any) -> Any: - obj = super(PsTypeMeta, cls).__call__(*args, const=const, **kwargs) - canonical_args = obj.__args__() + def __call__( + cls: PsTypeMeta, *args: Any, const: bool = False, **kwargs: Any + ) -> Any: + assert issubclass(cls, PsType) + canonical_args = cls.__canonical_args__(*args, **kwargs) key = (cls, canonical_args, const) if key in cls._instances: obj = cls._instances[key] else: + obj = super().__call__(*args, const=const, **kwargs) cls._instances[key] = obj return obj @@ -27,18 +71,19 @@ class PsType(metaclass=PsTypeMeta): Args: const: Const-qualification of this type + """ - **Implementation details for subclasses:** - `PsType` and its metaclass ``PsTypeMeta`` together implement a uniquing mechanism to ensure that of each type, - only one instance ever exists in the public. - For this to work, subclasses have to adhere to several rules: + def __new__(cls, *args, _pickle=False, **kwargs): + if _pickle: + # force unpickler to use metaclass uniquing mechanism + return cls(*args, **kwargs) + else: + return super().__new__(cls) - - All instances of `PsType` must be immutable. - - The `const` argument must be the last keyword argument to ``__init__`` and must be passed to the superclass - ``__init__``. - - The `__args__` method must return a tuple of positional arguments excluding the `const` property, - which, when passed to the class's constructor, create an identically-behaving instance. - """ + def __getnewargs_ex__(self): + args = self.__args__() + kwargs = {"const": self._const, "_pickle": True} + return args, kwargs def __init__(self, const: bool = False): self._const = const @@ -80,15 +125,21 @@ class PsType(metaclass=PsTypeMeta): """Arguments to this type, excluding the const-qualifier. The tuple returned by this method is used to serialize, deserialize, and check equality of types. - For each instantiable subclass ``MyType`` of ``PsType``, the following must hold: + For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:: + + t = MyType(< arguments >) + assert MyType(*t.__args__()) == t - ``` - t = MyType(< arguments >) - assert MyType(*t.__args__()) == t - ``` """ pass + @classmethod + @abstractmethod + def __canonical_args__(cls, *args, **kwargs): + """Return a tuple containing the positional and keyword arguments of ``__init__`` + in their canonical order.""" + pass + def _const_string(self) -> str: return "const " if self._const else "" diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 40f989a09..75fb35d22 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -4,11 +4,59 @@ from .types import ( PsType, PsPointerType, PsStructType, + PsNumericType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, ) +UserTypeSpec = str | type | np.dtype | PsType + + +def create_type(type_spec: UserTypeSpec) -> PsType: + """Create a pystencils type object from a variety of specifications. + + This function converts several possible representations of data types to an instance of `PsType`. + The ``type_spec`` argument can be any of the following: + + - Strings (`str`): will be parsed as common C types, throwing an exception if that fails. + To construct a `PsCustomType` instead, use the constructor of `PsCustomType` + or its abbreviation `types.quick.Custom`. + - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: + - `int` becomes a signed 64-bit integer + - `float` becomes a double-precision IEEE-754 float + - No others are supported at the moment + - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) + are converted to pystencils scalar data types + - Instances of `numpy.dtype`: Attempt to interpret scalar types like above, and structured types as structs. + - Instances of `PsType` will be returned as they are + + Args: + type_spec: The data type, in one of the above formats + """ + + from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype + + if isinstance(type_spec, PsType): + return type_spec + if isinstance(type_spec, str): + return parse_type_string(type_spec) + if isinstance(type_spec, type): + return interpret_python_type(type_spec) + if isinstance(type_spec, np.dtype): + return interpret_numpy_dtype(type_spec) + raise ValueError(f"{type_spec} is not a valid type specification.") + + +def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: + """Like `create_type`, but only for numeric types.""" + dtype = create_type(type_spec) + if not isinstance(dtype, PsNumericType): + raise ValueError( + f"Given type {type_spec} does not translate to a numeric type." + ) + return dtype + def interpret_python_type(t: type) -> PsType: if t is int: diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py index 1c44ba398..146528c65 100644 --- a/src/pystencils/types/quick.py +++ b/src/pystencils/types/quick.py @@ -2,12 +2,8 @@ from __future__ import annotations -import numpy as np - from .types import ( - PsType, PsCustomType, - PsNumericType, PsScalarType, PsBoolType, PsPointerType, @@ -18,54 +14,6 @@ from .types import ( PsIeeeFloatType, ) -UserTypeSpec = str | type | np.dtype | PsType - - -def create_type(type_spec: UserTypeSpec) -> PsType: - """Create a pystencils type object from a variety of specifications. - - This function converts several possible representations of data types to an instance of `PsType`. - The ``type_spec`` argument can be any of the following: - - - Strings (`str`): will be parsed as common C types, throwing an exception if that fails. - To construct a `PsCustomType` instead, use the constructor of `PsCustomType` - or its abbreviation `types.quick.Custom`. - - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: - - `int` becomes a signed 64-bit integer - - `float` becomes a double-precision IEEE-754 float - - No others are supported at the moment - - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) - are converted to pystencils scalar data types - - Instances of `numpy.dtype`: Attempt to interpret scalar types like above, and structured types as structs. - - Instances of `PsType` will be returned as they are - - Args: - type_spec: The data type, in one of the above formats - """ - - from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype - - if isinstance(type_spec, PsType): - return type_spec - if isinstance(type_spec, str): - return parse_type_string(type_spec) - if isinstance(type_spec, type): - return interpret_python_type(type_spec) - if isinstance(type_spec, np.dtype): - return interpret_numpy_dtype(type_spec) - raise ValueError(f"{type_spec} is not a valid type specification.") - - -def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: - """Like `create_type`, but only for numeric types.""" - dtype = create_type(type_spec) - if not isinstance(dtype, PsNumericType): - raise ValueError( - f"Given type {type_spec} does not translate to a numeric type." - ) - return dtype - - Custom = PsCustomType """Custom data types are modelled only by their name.""" @@ -91,7 +39,7 @@ Int = PsSignedIntegerType """``Int(width)`` matches ``PsSignedIntegerType(width)``""" SInt = PsSignedIntegerType -"""``SInt(width)` matches `PsSignedIntegerType(width)``""" +"""``SInt(width)`` matches ``PsSignedIntegerType(width)``""" Fp = PsIeeeFloatType -"""``Fp(width)` matches `PsIeeeFloatType(width)``""" +"""``Fp(width)`` matches ``PsIeeeFloatType(width)``""" diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index b61c421d0..5499d3591 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -22,9 +22,9 @@ class PsCustomType(PsType): super().__init__(const) self._name = name - @property - def name(self) -> str: - return self._name + @classmethod + def __canonical_args__(cls, name: str): + return (name,) def __args__(self) -> tuple[Any, ...]: """ @@ -34,6 +34,10 @@ class PsCustomType(PsType): """ return (self._name,) + @property + def name(self) -> str: + return self._name + def c_string(self) -> str: return f"{self._const_string()} {self._name}" @@ -45,7 +49,7 @@ class PsDereferencableType(PsType, ABC): """Base class for subscriptable types. `PsDereferencableType` represents any type that may be dereferenced and may - occur as the base of a subscript, that is, before the C `[]` operator. + occur as the base of a subscript, that is, before the C ``[]`` operator. Args: base_type: The base type, which is the type of the object obtained by dereferencing. @@ -76,6 +80,10 @@ class PsPointerType(PsDereferencableType): super().__init__(base_type, const) self._restrict = restrict + @classmethod + def __canonical_args__(cls, base_type: PsType, restrict: bool = True): + return (base_type, restrict) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsPointerType(PsBoolType()) @@ -106,6 +114,10 @@ class PsArrayType(PsDereferencableType): self._length = length super().__init__(base_type, const) + @classmethod + def __canonical_args__(cls, base_type: PsType, length: int | None = None): + return (base_type, length) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsArrayType(PsBoolType(), 13) @@ -138,6 +150,13 @@ class PsStructType(PsType): name: str dtype: PsType + @staticmethod + def _canonical_members(members: Sequence[PsStructType.Member | tuple[str, PsType]]): + return tuple( + (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) + for m in members + ) + def __init__( self, members: Sequence[PsStructType.Member | tuple[str, PsType]], @@ -147,10 +166,7 @@ class PsStructType(PsType): super().__init__(const=const) self._name = name - self._members = tuple( - (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) - for m in members - ) + self._members = self._canonical_members(members) names: set[str] = set() for member in self._members: @@ -158,6 +174,14 @@ class PsStructType(PsType): raise ValueError(f"Duplicate struct member name: {member.name}") names.add(member.name) + @classmethod + def __canonical_args__( + cls, + members: Sequence[PsStructType.Member | tuple[str, PsType]], + name: str | None = None, + ): + return (cls._canonical_members(members), name) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname") @@ -310,6 +334,10 @@ class PsVectorType(PsNumericType): self._vector_entries = vector_entries self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) + @classmethod + def __canonical_args__(cls, scalar_type: PsScalarType, vector_entries: int): + return (scalar_type, vector_entries) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsVectorType(PsBoolType(), 8) @@ -390,6 +418,10 @@ class PsBoolType(PsScalarType): def __init__(self, const: bool = False): super().__init__(const) + @classmethod + def __canonical_args__(cls, *args, **kwargs): + return () + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsBoolType() @@ -519,6 +551,10 @@ class PsSignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, True, const) + @classmethod + def __canonical_args__(cls, width: int): + return (width,) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsSignedIntegerType(32) @@ -544,6 +580,10 @@ class PsUnsignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, False, const) + @classmethod + def __canonical_args__(cls, width: int): + return (width,) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsUnsignedIntegerType(32) @@ -576,6 +616,10 @@ class PsIeeeFloatType(PsScalarType): super().__init__(const) self._width = width + @classmethod + def __canonical_args__(cls, width: int): + return (width,) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsIeeeFloatType(32) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index abe22ccc1..adca2245b 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -9,8 +9,8 @@ from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.constants import PsConstant -from pystencils.types import constify -from pystencils.types.quick import Fp, create_type, create_numeric_type +from pystencils.types import constify, create_type, create_numeric_type +from pystencils.types.quick import Fp from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 74467080a..65cbf9d08 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -1,5 +1,6 @@ import pytest import numpy as np +import pickle from pystencils.backend.exceptions import PsInternalCompilerError from pystencils.types import * @@ -22,10 +23,16 @@ def test_parsing_positive(): assert create_type("const uint32_t * restrict") is Ptr( UInt(32, const=True), restrict=True ) - assert create_type("float * * const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False) - assert create_type("float * * restrict const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True) + assert create_type("float * * const") is Ptr( + Ptr(Fp(32), restrict=False), const=True, restrict=False + ) + assert create_type("float * * restrict const") is Ptr( + Ptr(Fp(32), restrict=False), const=True, restrict=True + ) assert create_type("uint16 * const") is Ptr(UInt(16), const=True, restrict=False) - assert create_type("uint64 const * const") is Ptr(UInt(64, const=True), const=True, restrict=False) + assert create_type("uint64 const * const") is Ptr( + UInt(64, const=True), const=True, restrict=False + ) def test_parsing_negative(): @@ -104,7 +111,7 @@ def test_constify(): t = PsCustomType("std::shared_ptr< Custom >") assert deconstify(t) is t assert deconstify(constify(t)) is t - + s = PsCustomType("Field", const=True) assert constify(s) is s @@ -131,3 +138,29 @@ def test_struct_types(): assert str(t) == "<anonymous>" with pytest.raises(PsTypeError): t.c_string() + + +def test_pickle(): + types = [ + Bool(const=True), + Bool(const=False), + Custom("std::vector< uint_t >", const=False), + Ptr(Fp(32, const=False), restrict=True, const=True), + SInt(32, const=True), + SInt(16, const=False), + UInt(8, const=False), + UInt(width=16, const=False), + Int(width=32, const=False), + Fp(width=16, const=True), + PsStructType([("x", UInt(32)), ("y", UInt(32)), ("val", Fp(64))], "myStruct"), + PsStructType([("data", Fp(32))], "None"), + PsArrayType(Fp(16, const=True), 42), + PsArrayType(PsVectorType(Fp(32), 8, const=False), 42) + ] + + dumped = pickle.dumps(types) + restored = pickle.loads(dumped) + + for t1, t2 in zip(types, restored): + assert t1 == t2 + assert t1 is t2 -- GitLab