diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 664f205703a3d5799b003f6b222786d26e947f96..dbe3ae0eb892449efee86123de62ca94b6c28c7e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -330,6 +330,18 @@ nbackend-unit-tests: tags: - docker +doctest: + stage: "Unit Tests" + needs: [] + image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full + before_script: + - pip install -e .[tests] + script: + - pytest src/pystencils/backend + - pytest src/pystencils/types + tags: + - docker + # -------------------- Documentation --------------------------------------------------------------------- diff --git a/docs/source/api/types.rst b/docs/source/api/types.rst index 1794fd8568fbf547f2cfea71af0a1126342ab576..ce08697e20d327073ed7c2f450989ecce21b7a63 100644 --- a/docs/source/api/types.rst +++ b/docs/source/api/types.rst @@ -1,21 +1,26 @@ -********** -Data Types -********** - -Type System Module (pystencils.types) -------------------------------------- +*********** +Type System +*********** .. automodule:: pystencils.types +Basic Functions +------------------------------------- + .. autofunction:: pystencils.types.create_type .. autofunction:: pystencils.types.create_numeric_type +.. autofunction:: pystencils.types.constify +.. autofunction:: pystencils.types.deconstify Data Type Class Hierarchy ------------------------- -.. automodule:: pystencils.types.basic_types +.. autoclass:: pystencils.types.PsType + :members: + +.. automodule:: pystencils.types.types :members: @@ -23,4 +28,17 @@ Data Type Abbreviations ----------------------- .. automodule:: pystencils.types.quick - :members: \ No newline at end of file + :members: + + +Metaclass, Base Class and Uniquing Mechanisms +--------------------------------------------- + +.. automodule:: pystencils.types.meta + +.. autoclass:: pystencils.types.meta.PsTypeMeta + :members: + +.. autofunction:: pystencils.types.PsType.__args__ + +.. autofunction:: pystencils.types.PsType.__canonical_args__ diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 2a82b8a43ebf0f0acb78d267b1aeced6b06ce3cd..ac0aa3ae0c639ad1ace6d7827933ea7a1c9e52d3 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -17,7 +17,7 @@ from pystencils.stencil import direction_string_to_offset, inverse_direction, of from pystencils.types import PsType, PsStructType, create_type from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) from pystencils.sympyextensions.math import is_integer_sequence -from pystencils.types.quick import UserTypeSpec +from pystencils.types import UserTypeSpec __all__ = ['Field', 'fields', 'FieldType', 'Field'] diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 71caf5a2fb60528c9db6ce46259ad24624e8effd..e022db511ed9e637d0a4c2eea31d62a2214dd9ca 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,7 +1,6 @@ import sympy as sp -from ..types import PsType, PsNumericType, PsPointerType, PsBoolType -from ..types.quick import create_type +from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type def assumptions_from_dtype(dtype: PsType): diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 05e14fa7083bf501aa267a42f6a1e6a313d90d64..5f433583936f89033920e43116c5d05a38833512 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -4,8 +4,12 @@ to model data types. Data types are used extensively within the code generator, but can safely be ignored by most users unless you wish to force certain types on symbols, generate mixed-precision kernels, et cetera. -For more user-friendly and less verbose access to the type modelling system, refer to -the `pystencils.types.quick` submodule. +The various classes that constitute the pystencils type system are implemented in +`pystencils.types.types`; most have abbreviated names defined in `pystencils.types.quick`. + +For more information about the type system's internal workings, and developer's guidance on +how to extend it, refer to `pystencils.types.meta`. + """ from .meta import PsType, constify, deconstify @@ -26,7 +30,7 @@ from .types import ( PsIeeeFloatType, ) -from .quick import UserTypeSpec, create_type, create_numeric_type +from .parsing import UserTypeSpec, create_type, create_numeric_type from .exception import PsTypeError diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py index 4a115ef97ee3526d3bf273dcffbca9db51006e1a..915943cf8d730c9925c72e5f4896dd9591321d7a 100644 --- a/src/pystencils/types/meta.py +++ b/src/pystencils/types/meta.py @@ -1,3 +1,38 @@ +""" +Although mostly invisible to the user, types are ubiquitous throughout pystencils. +They are created and converted in many places, especially in the code generation backend. +To handle and compare types more efficiently, the pystencils type system implements +a uniquing mechanism to ensure that at any point there exists only one instance of each type. +This means, for example, if a 32-bit const unsigned integer type gets created in two places +at two different times in the program, the two types don't just behave identically, but +in fact refer to the same object: + +>>> from pystencils.types import PsUnsignedIntegerType +>>> t1 = PsUnsignedIntegerType(32, const=True) +>>> t2 = PsUnsignedIntegerType(32, const=True) +>>> t1 is t2 +True + +Both calls to `PsUnsignedIntegerType` return the same object. This is ensured by the +`PsTypeMeta` metaclass. +This metaclass holds an internal registry of all type objects ever created, +and alters the class instantiation mechanism such that whenever a type is instantiated +a second time with the same arguments, the pre-existing instance is found and returned instead. + +For this to work, all instantiable subclasses of `PsType` must implement the following protocol: + +- The ``const`` parameter must be the last keyword parameter of ``__init__``. +- The ``__canonical_args__`` classmethod must have the same signature as ``__init__``, except it does + not take the ``const`` parameter. It must return a tuple containing all the positional and keyword + arguments in their canonical order. This method is used by `PsTypeMeta` to identify instances of the type, + and to catch the various different possibilities Python offers for passing function arguments. +- The ``__args__`` method, when called on an instance of the type, must return a tuple containing the constructor + arguments required to create that exact instance. + +Developers intending to extend the type class hierarchy are advised to study the implementations +of this protocol in the existing classes. +""" + from __future__ import annotations from abc import ABCMeta, abstractmethod @@ -6,17 +41,26 @@ import numpy as np class PsTypeMeta(ABCMeta): + """Metaclass for the `PsType` hierarchy. + + `PsTypeMeta` holds an internal cache of all instances of `PsType` and overrides object creation + such that whenever a type gets instantiated more than once, instead of creating a new object, + the existing object is returned. + """ _instances: dict[Any, PsType] = dict() - def __call__(cls, *args: Any, const: bool = False, **kwargs: Any) -> Any: - obj = super(PsTypeMeta, cls).__call__(*args, const=const, **kwargs) - canonical_args = obj.__args__() + def __call__( + cls: PsTypeMeta, *args: Any, const: bool = False, **kwargs: Any + ) -> Any: + assert issubclass(cls, PsType) + canonical_args = cls.__canonical_args__(*args, **kwargs) key = (cls, canonical_args, const) if key in cls._instances: obj = cls._instances[key] else: + obj = super().__call__(*args, const=const, **kwargs) cls._instances[key] = obj return obj @@ -27,18 +71,19 @@ class PsType(metaclass=PsTypeMeta): Args: const: Const-qualification of this type + """ - **Implementation details for subclasses:** - `PsType` and its metaclass ``PsTypeMeta`` together implement a uniquing mechanism to ensure that of each type, - only one instance ever exists in the public. - For this to work, subclasses have to adhere to several rules: + def __new__(cls, *args, _pickle=False, **kwargs): + if _pickle: + # force unpickler to use metaclass uniquing mechanism + return cls(*args, **kwargs) + else: + return super().__new__(cls) - - All instances of `PsType` must be immutable. - - The `const` argument must be the last keyword argument to ``__init__`` and must be passed to the superclass - ``__init__``. - - The `__args__` method must return a tuple of positional arguments excluding the `const` property, - which, when passed to the class's constructor, create an identically-behaving instance. - """ + def __getnewargs_ex__(self): + args = self.__args__() + kwargs = {"const": self._const, "_pickle": True} + return args, kwargs def __init__(self, const: bool = False): self._const = const @@ -80,15 +125,21 @@ class PsType(metaclass=PsTypeMeta): """Arguments to this type, excluding the const-qualifier. The tuple returned by this method is used to serialize, deserialize, and check equality of types. - For each instantiable subclass ``MyType`` of ``PsType``, the following must hold: + For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:: + + t = MyType(< arguments >) + assert MyType(*t.__args__()) == t - ``` - t = MyType(< arguments >) - assert MyType(*t.__args__()) == t - ``` """ pass + @classmethod + @abstractmethod + def __canonical_args__(cls, *args, **kwargs): + """Return a tuple containing the positional and keyword arguments of ``__init__`` + in their canonical order.""" + pass + def _const_string(self) -> str: return "const " if self._const else "" diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 40f989a09d35e2a530b0aacef9a212b3044fe738..75fb35d223a9dbc1cfd1090a661e0c61a5335cf8 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -4,11 +4,59 @@ from .types import ( PsType, PsPointerType, PsStructType, + PsNumericType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, ) +UserTypeSpec = str | type | np.dtype | PsType + + +def create_type(type_spec: UserTypeSpec) -> PsType: + """Create a pystencils type object from a variety of specifications. + + This function converts several possible representations of data types to an instance of `PsType`. + The ``type_spec`` argument can be any of the following: + + - Strings (`str`): will be parsed as common C types, throwing an exception if that fails. + To construct a `PsCustomType` instead, use the constructor of `PsCustomType` + or its abbreviation `types.quick.Custom`. + - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: + - `int` becomes a signed 64-bit integer + - `float` becomes a double-precision IEEE-754 float + - No others are supported at the moment + - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) + are converted to pystencils scalar data types + - Instances of `numpy.dtype`: Attempt to interpret scalar types like above, and structured types as structs. + - Instances of `PsType` will be returned as they are + + Args: + type_spec: The data type, in one of the above formats + """ + + from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype + + if isinstance(type_spec, PsType): + return type_spec + if isinstance(type_spec, str): + return parse_type_string(type_spec) + if isinstance(type_spec, type): + return interpret_python_type(type_spec) + if isinstance(type_spec, np.dtype): + return interpret_numpy_dtype(type_spec) + raise ValueError(f"{type_spec} is not a valid type specification.") + + +def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: + """Like `create_type`, but only for numeric types.""" + dtype = create_type(type_spec) + if not isinstance(dtype, PsNumericType): + raise ValueError( + f"Given type {type_spec} does not translate to a numeric type." + ) + return dtype + def interpret_python_type(t: type) -> PsType: if t is int: diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py index 1c44ba3980af5cafd746a9b959602a5506e002ad..146528c65ebfe66074b6412002db4ca151fde898 100644 --- a/src/pystencils/types/quick.py +++ b/src/pystencils/types/quick.py @@ -2,12 +2,8 @@ from __future__ import annotations -import numpy as np - from .types import ( - PsType, PsCustomType, - PsNumericType, PsScalarType, PsBoolType, PsPointerType, @@ -18,54 +14,6 @@ from .types import ( PsIeeeFloatType, ) -UserTypeSpec = str | type | np.dtype | PsType - - -def create_type(type_spec: UserTypeSpec) -> PsType: - """Create a pystencils type object from a variety of specifications. - - This function converts several possible representations of data types to an instance of `PsType`. - The ``type_spec`` argument can be any of the following: - - - Strings (`str`): will be parsed as common C types, throwing an exception if that fails. - To construct a `PsCustomType` instead, use the constructor of `PsCustomType` - or its abbreviation `types.quick.Custom`. - - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: - - `int` becomes a signed 64-bit integer - - `float` becomes a double-precision IEEE-754 float - - No others are supported at the moment - - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) - are converted to pystencils scalar data types - - Instances of `numpy.dtype`: Attempt to interpret scalar types like above, and structured types as structs. - - Instances of `PsType` will be returned as they are - - Args: - type_spec: The data type, in one of the above formats - """ - - from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype - - if isinstance(type_spec, PsType): - return type_spec - if isinstance(type_spec, str): - return parse_type_string(type_spec) - if isinstance(type_spec, type): - return interpret_python_type(type_spec) - if isinstance(type_spec, np.dtype): - return interpret_numpy_dtype(type_spec) - raise ValueError(f"{type_spec} is not a valid type specification.") - - -def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: - """Like `create_type`, but only for numeric types.""" - dtype = create_type(type_spec) - if not isinstance(dtype, PsNumericType): - raise ValueError( - f"Given type {type_spec} does not translate to a numeric type." - ) - return dtype - - Custom = PsCustomType """Custom data types are modelled only by their name.""" @@ -91,7 +39,7 @@ Int = PsSignedIntegerType """``Int(width)`` matches ``PsSignedIntegerType(width)``""" SInt = PsSignedIntegerType -"""``SInt(width)` matches `PsSignedIntegerType(width)``""" +"""``SInt(width)`` matches ``PsSignedIntegerType(width)``""" Fp = PsIeeeFloatType -"""``Fp(width)` matches `PsIeeeFloatType(width)``""" +"""``Fp(width)`` matches ``PsIeeeFloatType(width)``""" diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index b61c421d0a2de7ec2104342f746dc921f5fe3fd9..5499d359158f2d33669edf977abf628c800919ed 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -22,9 +22,9 @@ class PsCustomType(PsType): super().__init__(const) self._name = name - @property - def name(self) -> str: - return self._name + @classmethod + def __canonical_args__(cls, name: str): + return (name,) def __args__(self) -> tuple[Any, ...]: """ @@ -34,6 +34,10 @@ class PsCustomType(PsType): """ return (self._name,) + @property + def name(self) -> str: + return self._name + def c_string(self) -> str: return f"{self._const_string()} {self._name}" @@ -45,7 +49,7 @@ class PsDereferencableType(PsType, ABC): """Base class for subscriptable types. `PsDereferencableType` represents any type that may be dereferenced and may - occur as the base of a subscript, that is, before the C `[]` operator. + occur as the base of a subscript, that is, before the C ``[]`` operator. Args: base_type: The base type, which is the type of the object obtained by dereferencing. @@ -76,6 +80,10 @@ class PsPointerType(PsDereferencableType): super().__init__(base_type, const) self._restrict = restrict + @classmethod + def __canonical_args__(cls, base_type: PsType, restrict: bool = True): + return (base_type, restrict) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsPointerType(PsBoolType()) @@ -106,6 +114,10 @@ class PsArrayType(PsDereferencableType): self._length = length super().__init__(base_type, const) + @classmethod + def __canonical_args__(cls, base_type: PsType, length: int | None = None): + return (base_type, length) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsArrayType(PsBoolType(), 13) @@ -138,6 +150,13 @@ class PsStructType(PsType): name: str dtype: PsType + @staticmethod + def _canonical_members(members: Sequence[PsStructType.Member | tuple[str, PsType]]): + return tuple( + (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) + for m in members + ) + def __init__( self, members: Sequence[PsStructType.Member | tuple[str, PsType]], @@ -147,10 +166,7 @@ class PsStructType(PsType): super().__init__(const=const) self._name = name - self._members = tuple( - (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) - for m in members - ) + self._members = self._canonical_members(members) names: set[str] = set() for member in self._members: @@ -158,6 +174,14 @@ class PsStructType(PsType): raise ValueError(f"Duplicate struct member name: {member.name}") names.add(member.name) + @classmethod + def __canonical_args__( + cls, + members: Sequence[PsStructType.Member | tuple[str, PsType]], + name: str | None = None, + ): + return (cls._canonical_members(members), name) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname") @@ -310,6 +334,10 @@ class PsVectorType(PsNumericType): self._vector_entries = vector_entries self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) + @classmethod + def __canonical_args__(cls, scalar_type: PsScalarType, vector_entries: int): + return (scalar_type, vector_entries) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsVectorType(PsBoolType(), 8) @@ -390,6 +418,10 @@ class PsBoolType(PsScalarType): def __init__(self, const: bool = False): super().__init__(const) + @classmethod + def __canonical_args__(cls, *args, **kwargs): + return () + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsBoolType() @@ -519,6 +551,10 @@ class PsSignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, True, const) + @classmethod + def __canonical_args__(cls, width: int): + return (width,) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsSignedIntegerType(32) @@ -544,6 +580,10 @@ class PsUnsignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, False, const) + @classmethod + def __canonical_args__(cls, width: int): + return (width,) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsUnsignedIntegerType(32) @@ -576,6 +616,10 @@ class PsIeeeFloatType(PsScalarType): super().__init__(const) self._width = width + @classmethod + def __canonical_args__(cls, width: int): + return (width,) + def __args__(self) -> tuple[Any, ...]: """ >>> t = PsIeeeFloatType(32) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index abe22ccc1ce32a38cd26812cb9e8710ea4cbfbbf..adca2245b02feeac286a4507c27b7fb570620af8 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -9,8 +9,8 @@ from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.constants import PsConstant -from pystencils.types import constify -from pystencils.types.quick import Fp, create_type, create_numeric_type +from pystencils.types import constify, create_type, create_numeric_type +from pystencils.types.quick import Fp from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 74467080a5bfaf18fcc9441d8727fa42c1c159a8..65cbf9d08aec19c857912af3af98069d09801ffa 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -1,5 +1,6 @@ import pytest import numpy as np +import pickle from pystencils.backend.exceptions import PsInternalCompilerError from pystencils.types import * @@ -22,10 +23,16 @@ def test_parsing_positive(): assert create_type("const uint32_t * restrict") is Ptr( UInt(32, const=True), restrict=True ) - assert create_type("float * * const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False) - assert create_type("float * * restrict const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True) + assert create_type("float * * const") is Ptr( + Ptr(Fp(32), restrict=False), const=True, restrict=False + ) + assert create_type("float * * restrict const") is Ptr( + Ptr(Fp(32), restrict=False), const=True, restrict=True + ) assert create_type("uint16 * const") is Ptr(UInt(16), const=True, restrict=False) - assert create_type("uint64 const * const") is Ptr(UInt(64, const=True), const=True, restrict=False) + assert create_type("uint64 const * const") is Ptr( + UInt(64, const=True), const=True, restrict=False + ) def test_parsing_negative(): @@ -104,7 +111,7 @@ def test_constify(): t = PsCustomType("std::shared_ptr< Custom >") assert deconstify(t) is t assert deconstify(constify(t)) is t - + s = PsCustomType("Field", const=True) assert constify(s) is s @@ -131,3 +138,29 @@ def test_struct_types(): assert str(t) == "<anonymous>" with pytest.raises(PsTypeError): t.c_string() + + +def test_pickle(): + types = [ + Bool(const=True), + Bool(const=False), + Custom("std::vector< uint_t >", const=False), + Ptr(Fp(32, const=False), restrict=True, const=True), + SInt(32, const=True), + SInt(16, const=False), + UInt(8, const=False), + UInt(width=16, const=False), + Int(width=32, const=False), + Fp(width=16, const=True), + PsStructType([("x", UInt(32)), ("y", UInt(32)), ("val", Fp(64))], "myStruct"), + PsStructType([("data", Fp(32))], "None"), + PsArrayType(Fp(16, const=True), 42), + PsArrayType(PsVectorType(Fp(32), 8, const=False), 42) + ] + + dumped = pickle.dumps(types) + restored = pickle.loads(dumped) + + for t1, t2 in zip(types, restored): + assert t1 == t2 + assert t1 is t2