diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 664f205703a3d5799b003f6b222786d26e947f96..dbe3ae0eb892449efee86123de62ca94b6c28c7e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -330,6 +330,18 @@ nbackend-unit-tests: tags: - docker +doctest: + stage: "Unit Tests" + needs: [] + image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full + before_script: + - pip install -e .[tests] + script: + - pytest src/pystencils/backend + - pytest src/pystencils/types + tags: + - docker + # -------------------- Documentation --------------------------------------------------------------------- diff --git a/docs/source/api/types.rst b/docs/source/api/types.rst index 1794fd8568fbf547f2cfea71af0a1126342ab576..5a740c05818e632abfa1f7ca810d275c48ead77a 100644 --- a/docs/source/api/types.rst +++ b/docs/source/api/types.rst @@ -1,21 +1,30 @@ -********** -Data Types -********** - -Type System Module (pystencils.types) -------------------------------------- +*********** +Type System +*********** .. automodule:: pystencils.types + +Basic Functions +------------------------------------- + .. autofunction:: pystencils.types.create_type .. autofunction:: pystencils.types.create_numeric_type - +.. autofunction:: pystencils.types.constify +.. autofunction:: pystencils.types.deconstify Data Type Class Hierarchy ------------------------- -.. automodule:: pystencils.types.basic_types +.. inheritance-diagram:: pystencils.types.meta.PsType pystencils.types.types + :top-classes: pystencils.types.PsType + :parts: 1 + +.. autoclass:: pystencils.types.PsType + :members: + +.. automodule:: pystencils.types.types :members: @@ -23,4 +32,10 @@ Data Type Abbreviations ----------------------- .. automodule:: pystencils.types.quick - :members: \ No newline at end of file + :members: + + +Implementation Details +---------------------- + +.. automodule:: pystencils.types.meta diff --git a/docs/source/conf.py b/docs/source/conf.py index 45546d9b7faf8f565330bf3e39ed29eeb9498363..4f68a697eaf628e3c83dd844ae2d382c3b9eeb74 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -29,6 +29,7 @@ extensions = [ "sphinx.ext.intersphinx", "sphinx.ext.mathjax", "sphinx.ext.napoleon", + "sphinx.ext.inheritance_diagram", "nbsphinx", "sphinxcontrib.bibtex", "sphinx_autodoc_typehints", @@ -47,6 +48,12 @@ intersphinx_mapping = { "sympy": ("https://docs.sympy.org/latest/", None), } +# -- Options for inheritance diagrams----------------------------------------- + +inheritance_graph_attrs = { + "bgcolor": "white", +} + # -- Options for HTML output ------------------------------------------------- html_theme = "furo" diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 034675e6209c4eee01efcfe5e35984e7140f1780..1bf3c49807ff52a34bc9ab319f5da67e4fa59ebc 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -408,7 +408,7 @@ class Typifier: f"Unable to determine type of argument to AddressOf: {arg}" ) - ptr_type = PsPointerType(arg_tc.target_type, True) + ptr_type = PsPointerType(arg_tc.target_type, const=True) tc.apply_dtype(ptr_type, expr) case PsLookup(aggr, member_name): diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index ef4861aa7dd3b7e92275cb61a651e7f3bb0875c4..839cd34f4c9fada060a0ac6253b635f7c7812948 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,5 +1,5 @@ from pystencils.backend.functions import CFunction, PsMathFunction -from pystencils.types.basic_types import PsType +from pystencils.types.types import PsType from .platform import Platform from ..kernelcreation.iteration_space import ( diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 2a82b8a43ebf0f0acb78d267b1aeced6b06ce3cd..ac0aa3ae0c639ad1ace6d7827933ea7a1c9e52d3 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -17,7 +17,7 @@ from pystencils.stencil import direction_string_to_offset, inverse_direction, of from pystencils.types import PsType, PsStructType, create_type from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) from pystencils.sympyextensions.math import is_integer_sequence -from pystencils.types.quick import UserTypeSpec +from pystencils.types import UserTypeSpec __all__ = ['Field', 'fields', 'FieldType', 'Field'] diff --git a/src/pystencils/sympyextensions/pointers.py b/src/pystencils/sympyextensions/pointers.py index 130338c99583bc4874b53cc804f4ff4761f0e580..a814f941e0a2968be7fedfbb82bff612ae8f1d1a 100644 --- a/src/pystencils/sympyextensions/pointers.py +++ b/src/pystencils/sympyextensions/pointers.py @@ -26,6 +26,6 @@ class AddressOf(sp.Function): @property def dtype(self): if hasattr(self.args[0], 'dtype'): - return PsPointerType(self.args[0].dtype, const=True, restrict=True) + return PsPointerType(self.args[0].dtype, restrict=True, const=True) else: raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}') diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 541f9aed707b6f0ead80843f3c9bde27e24b0545..e022db511ed9e637d0a4c2eea31d62a2214dd9ca 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,7 +1,6 @@ import sympy as sp -from ..types import PsType, PsNumericType, PsPointerType, PsBoolType -from ..types.quick import create_type +from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type def assumptions_from_dtype(dtype: PsType): @@ -172,7 +171,7 @@ class FieldPointerSymbol(TypedSymbol): def __new_stage2__(cls, field_name, field_dtype: PsType, const: bool): name = f"_data_{field_name}" - dtype = PsPointerType(field_dtype, const=const, restrict=True) + dtype = PsPointerType(field_dtype, restrict=True, const=const) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj.field_name = field_name return obj diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index 2c7bac59bf17318f86dd2d24352ead9cde4ee163..e9b67096baf7dca5d63e4b9ab57a85a8195bc51e 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -3,13 +3,11 @@ The `pystencils.types` module contains the set of classes used by pystencils to model data types. Data types are used extensively within the code generator, but can safely be ignored by most users unless you wish to force certain types on symbols, generate mixed-precision kernels, et cetera. - -For more user-friendly and less verbose access to the type modelling system, refer to -the `pystencils.types.quick` submodule. """ -from .basic_types import ( - PsType, +from .meta import PsType, constify, deconstify + +from .types import ( PsCustomType, PsStructType, PsNumericType, @@ -23,11 +21,9 @@ from .basic_types import ( PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, - constify, - deconstify, ) -from .quick import UserTypeSpec, create_type, create_numeric_type +from .parsing import UserTypeSpec, create_type, create_numeric_type from .exception import PsTypeError diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..1d605edf89d7f30169dbc2601c85fa4787bb3128 --- /dev/null +++ b/src/pystencils/types/meta.py @@ -0,0 +1,186 @@ +""" + +Caching of Instances +^^^^^^^^^^^^^^^^^^^^ + +To handle and compare types more efficiently, the pystencils type system customizes class +instantiation to cache and reuse existing instances of types. +This means, for example, if a 32-bit const unsigned integer type gets created in two places +in the program, the resulting objects are exactly the same: + +>>> from pystencils.types import PsUnsignedIntegerType +>>> t1 = PsUnsignedIntegerType(32, const=True) +>>> t2 = PsUnsignedIntegerType(32, const=True) +>>> t1 is t2 +True + +This mechanism is implemented by the metaclass `PsTypeMeta`. It is not perfect, however; +some parts of Python that bypass the regular object creation sequence, such as `pickle` and +`copy.copy`, may create additional instances of types. + +.. autoclass:: pystencils.types.meta.PsTypeMeta + :members: + +Extending the Type System +^^^^^^^^^^^^^^^^^^^^^^^^^ + +When extending the type system's class hierarchy, new classes need to implement at least the internal +method `__args__`. This method, when called on a type object, must return a hashable sequence of arguments +-- not including the const-qualifier -- +that can be used to recreate that exact type. It is used internally to compute hashes and compare equality +of types, as well as for const-conversion. + +.. autofunction:: pystencils.types.PsType.__args__ + +""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TypeVar, Any, cast +import numpy as np + + +class PsTypeMeta(ABCMeta): + """Metaclass for the `PsType` hierarchy. + + `PsTypeMeta` holds an internal cache of all created instances of `PsType` and overrides object creation + such that whenever a type gets instantiated more than once with the same argument list, + instead of creating a new object, the existing object is returned. + """ + + _instances: dict[Any, PsType] = dict() + + def __call__(cls: PsTypeMeta, *args: Any, **kwargs: Any) -> Any: + assert issubclass(cls, PsType) + kwarg_tuples = tuple(sorted(kwargs.items(), key=lambda t: t[0])) + + try: + key = (cls, args, kwarg_tuples) + + if key in cls._instances: + return cls._instances[key] + except TypeError: + key = None + + obj = super().__call__(*args, **kwargs) + canonical_key = (cls, obj.__args__(), (("const", obj.const),)) + + if canonical_key in cls._instances: + obj = cls._instances[canonical_key] + else: + cls._instances[canonical_key] = obj + + if key is not None: + cls._instances[key] = obj + + return obj + + +class PsType(metaclass=PsTypeMeta): + """Base class for all pystencils types. + + Args: + const: Const-qualification of this type + """ + + # ------------------------------------------------------------------------------------------- + # Arguments, Equality and Hashing + # ------------------------------------------------------------------------------------------- + + @abstractmethod + def __args__(self) -> tuple[Any, ...]: + """Return the arguments used to create this instance, in canonical order, excluding the const-qualifier. + + The tuple returned by this method must be hashable and for each instantiable subclass + ``MyType`` of ``PsType``, the following must hold:: + + t = MyType(< arguments >) + assert MyType(*t.__args__(), const=t.const) == t + + """ + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + if type(self) is not type(other): + return False + + other = cast(PsType, other) + return self.const == other.const and self.__args__() == other.__args__() + + def __hash__(self) -> int: + return hash((type(self), self.const, self.__args__())) + + # ------------------------------------------------------------------------------------------- + # Constructor and properties + # ------------------------------------------------------------------------------------------- + + def __init__(self, const: bool = False): + self._const = const + + self._requalified: PsType | None = None + + @property + def const(self) -> bool: + return self._const + + # ------------------------------------------------------------------------------------------- + # Optional Info + # ------------------------------------------------------------------------------------------- + + @property + def required_headers(self) -> set[str]: + """The set of header files required when this type occurs in generated code.""" + return set() + + @property + def itemsize(self) -> int | None: + """If this type has a valid in-memory size, return that size.""" + return None + + @property + def numpy_dtype(self) -> np.dtype | None: + """A np.dtype object representing this data type. + + Available both for backward compatibility and for interaction with the numpy-based runtime system. + """ + return None + + # ------------------------------------------------------------------------------------------- + # String Conversion + # ------------------------------------------------------------------------------------------- + + def _const_string(self) -> str: + return "const " if self._const else "" + + @abstractmethod + def c_string(self) -> str: + pass + + def __str__(self) -> str: + return self.c_string() + + +T = TypeVar("T", bound=PsType) + + +def constify(t: T) -> T: + """Adds the const qualifier to a given type.""" + if not t.const: + if t._requalified is None: + t._requalified = type(t)(*t.__args__(), const=True) # type: ignore + return cast(T, t._requalified) + else: + return t + + +def deconstify(t: T) -> T: + """Removes the const qualifier from a given type.""" + if t.const: + if t._requalified is None: + t._requalified = type(t)(*t.__args__(), const=False) # type: ignore + return cast(T, t._requalified) + else: + return t diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index e28b83ae7074d3b79c145a8b5b6e47b20120873f..75fb35d223a9dbc1cfd1090a661e0c61a5335cf8 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -1,14 +1,62 @@ import numpy as np -from .basic_types import ( +from .types import ( PsType, PsPointerType, PsStructType, + PsNumericType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, ) +UserTypeSpec = str | type | np.dtype | PsType + + +def create_type(type_spec: UserTypeSpec) -> PsType: + """Create a pystencils type object from a variety of specifications. + + This function converts several possible representations of data types to an instance of `PsType`. + The ``type_spec`` argument can be any of the following: + + - Strings (`str`): will be parsed as common C types, throwing an exception if that fails. + To construct a `PsCustomType` instead, use the constructor of `PsCustomType` + or its abbreviation `types.quick.Custom`. + - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: + - `int` becomes a signed 64-bit integer + - `float` becomes a double-precision IEEE-754 float + - No others are supported at the moment + - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) + are converted to pystencils scalar data types + - Instances of `numpy.dtype`: Attempt to interpret scalar types like above, and structured types as structs. + - Instances of `PsType` will be returned as they are + + Args: + type_spec: The data type, in one of the above formats + """ + + from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype + + if isinstance(type_spec, PsType): + return type_spec + if isinstance(type_spec, str): + return parse_type_string(type_spec) + if isinstance(type_spec, type): + return interpret_python_type(type_spec) + if isinstance(type_spec, np.dtype): + return interpret_numpy_dtype(type_spec) + raise ValueError(f"{type_spec} is not a valid type specification.") + + +def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: + """Like `create_type`, but only for numeric types.""" + dtype = create_type(type_spec) + if not isinstance(dtype, PsNumericType): + raise ValueError( + f"Given type {type_spec} does not translate to a numeric type." + ) + return dtype + def interpret_python_type(t: type) -> PsType: if t is int: @@ -76,13 +124,13 @@ def parse_type_string(s: str) -> PsType: base_type = parse_type_string(base) match suffix.split(): case []: - return PsPointerType(base_type, const=False, restrict=False) + return PsPointerType(base_type, restrict=False, const=False) case ["const"]: - return PsPointerType(base_type, const=True, restrict=False) + return PsPointerType(base_type, restrict=False, const=True) case ["restrict"]: - return PsPointerType(base_type, const=False, restrict=True) + return PsPointerType(base_type, restrict=True, const=False) case ["const", "restrict"] | ["restrict", "const"]: - return PsPointerType(base_type, const=True, restrict=True) + return PsPointerType(base_type, restrict=True, const=True) case _: raise ValueError(f"Could not parse token '{s}' as C type.") diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py index c1a3aadc5deee62a5db4593dc72c08c782075828..146528c65ebfe66074b6412002db4ca151fde898 100644 --- a/src/pystencils/types/quick.py +++ b/src/pystencils/types/quick.py @@ -2,12 +2,8 @@ from __future__ import annotations -import numpy as np - -from .basic_types import ( - PsType, +from .types import ( PsCustomType, - PsNumericType, PsScalarType, PsBoolType, PsPointerType, @@ -18,54 +14,6 @@ from .basic_types import ( PsIeeeFloatType, ) -UserTypeSpec = str | type | np.dtype | PsType - - -def create_type(type_spec: UserTypeSpec) -> PsType: - """Create a pystencils type object from a variety of specifications. - - This function converts several possible representations of data types to an instance of `PsType`. - The ``type_spec`` argument can be any of the following: - - - Strings (`str`): will be parsed as common C types, throwing an exception if that fails. - To construct a `PsCustomType` instead, use the constructor of `PsCustomType` - or its abbreviation `types.quick.Custom`. - - Python builtin data types (instances of `type`): Attempts to interpret Python numeric types like so: - - `int` becomes a signed 64-bit integer - - `float` becomes a double-precision IEEE-754 float - - No others are supported at the moment - - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html) - are converted to pystencils scalar data types - - Instances of `numpy.dtype`: Attempt to interpret scalar types like above, and structured types as structs. - - Instances of `PsType` will be returned as they are - - Args: - type_spec: The data type, in one of the above formats - """ - - from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype - - if isinstance(type_spec, PsType): - return type_spec - if isinstance(type_spec, str): - return parse_type_string(type_spec) - if isinstance(type_spec, type): - return interpret_python_type(type_spec) - if isinstance(type_spec, np.dtype): - return interpret_numpy_dtype(type_spec) - raise ValueError(f"{type_spec} is not a valid type specification.") - - -def create_numeric_type(type_spec: UserTypeSpec) -> PsNumericType: - """Like `create_type`, but only for numeric types.""" - dtype = create_type(type_spec) - if not isinstance(dtype, PsNumericType): - raise ValueError( - f"Given type {type_spec} does not translate to a numeric type." - ) - return dtype - - Custom = PsCustomType """Custom data types are modelled only by their name.""" @@ -91,7 +39,7 @@ Int = PsSignedIntegerType """``Int(width)`` matches ``PsSignedIntegerType(width)``""" SInt = PsSignedIntegerType -"""``SInt(width)` matches `PsSignedIntegerType(width)``""" +"""``SInt(width)`` matches ``PsSignedIntegerType(width)``""" Fp = PsIeeeFloatType -"""``Fp(width)` matches `PsIeeeFloatType(width)``""" +"""``Fp(width)`` matches ``PsIeeeFloatType(width)``""" diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/types.py similarity index 73% rename from src/pystencils/types/basic_types.py rename to src/pystencils/types/types.py index 3678ea126a4318c6ce94d8f7cd2f3103c45885dd..8e51f939705fa77ba5b543af8417890bf192f819 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/types.py @@ -1,78 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, TypeVar, Any, Sequence +from typing import final, Any, Sequence from dataclasses import dataclass -from copy import copy import numpy as np from .exception import PsTypeError - - -class PsType(ABC): - """Base class for all pystencils types. - - Args: - const: Const-qualification of this type - """ - - def __init__(self, const: bool = False): - self._const = const - - @property - def const(self) -> bool: - return self._const - - # ------------------------------------------------------------------------------------------- - # Optional Info - # ------------------------------------------------------------------------------------------- - - @property - def required_headers(self) -> set[str]: - """The set of header files required when this type occurs in generated code.""" - return set() - - @property - def itemsize(self) -> int | None: - """If this type has a valid in-memory size, return that size.""" - return None - - @property - def numpy_dtype(self) -> np.dtype | None: - """A np.dtype object representing this data type. - - Available both for backward compatibility and for interaction with the numpy-based runtime system. - """ - return None - - # ------------------------------------------------------------------------------------------- - # Internal virtual operations - # ------------------------------------------------------------------------------------------- - - def _base_equal(self, other: PsType) -> bool: - return type(self) is type(other) and self._const == other._const - - def _const_string(self) -> str: - return "const " if self._const else "" - - @abstractmethod - def c_string(self) -> str: - pass - - # ------------------------------------------------------------------------------------------- - # Dunder Methods - # ------------------------------------------------------------------------------------------- - - @abstractmethod - def __eq__(self, other: object) -> bool: - pass - - def __str__(self) -> str: - return self.c_string() - - @abstractmethod - def __hash__(self) -> int: - pass +from .meta import PsType, constify, deconstify class PsCustomType(PsType): @@ -88,18 +22,18 @@ class PsCustomType(PsType): super().__init__(const) self._name = name + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsCustomType("std::vector< int >") + >>> t == PsCustomType(*t.__args__()) + True + """ + return (self._name,) + @property def name(self) -> str: return self._name - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsCustomType): - return False - return self._base_equal(other) and self._name == other._name - - def __hash__(self) -> int: - return hash(("PsCustomType", self._name, self._const)) - def c_string(self) -> str: return f"{self._const_string()} {self._name}" @@ -111,7 +45,7 @@ class PsDereferencableType(PsType, ABC): """Base class for subscriptable types. `PsDereferencableType` represents any type that may be dereferenced and may - occur as the base of a subscript, that is, before the C `[]` operator. + occur as the base of a subscript, that is, before the C ``[]`` operator. Args: base_type: The base type, which is the type of the object obtained by dereferencing. @@ -138,22 +72,22 @@ class PsPointerType(PsDereferencableType): __match_args__ = ("base_type",) - def __init__(self, base_type: PsType, const: bool = False, restrict: bool = True): + def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False): super().__init__(base_type, const) self._restrict = restrict + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsPointerType(PsBoolType()) + >>> t == PsPointerType(*t.__args__()) + True + """ + return (self._base_type, self._restrict) + @property def restrict(self) -> bool: return self._restrict - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsPointerType): - return False - return self._base_equal(other) and self._base_type == other._base_type - - def __hash__(self) -> int: - return hash(("PsPointerType", self._base_type, self._restrict, self._const)) - def c_string(self) -> str: base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" @@ -172,6 +106,14 @@ class PsArrayType(PsDereferencableType): self._length = length super().__init__(base_type, const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsArrayType(PsBoolType(), 13) + >>> t == PsArrayType(*t.__args__()) + True + """ + return (self._base_type, self._length) + @property def length(self) -> int | None: return self._length @@ -179,19 +121,6 @@ class PsArrayType(PsDereferencableType): def c_string(self) -> str: return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]" - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsArrayType): - return False - - return ( - self._base_equal(other) - and self._base_type == other._base_type - and self._length == other._length - ) - - def __hash__(self) -> int: - return hash(("PsArrayType", self._base_type, self._length, self._const)) - def __repr__(self) -> str: return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})" @@ -209,6 +138,13 @@ class PsStructType(PsType): name: str dtype: PsType + @staticmethod + def _canonical_members(members: Sequence[PsStructType.Member | tuple[str, PsType]]): + return tuple( + (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) + for m in members + ) + def __init__( self, members: Sequence[PsStructType.Member | tuple[str, PsType]], @@ -218,10 +154,7 @@ class PsStructType(PsType): super().__init__(const=const) self._name = name - self._members = tuple( - (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m) - for m in members - ) + self._members = self._canonical_members(members) names: set[str] = set() for member in self._members: @@ -229,6 +162,14 @@ class PsStructType(PsType): raise ValueError(f"Duplicate struct member name: {member.name}") names.add(member.name) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname") + >>> t == PsStructType(*t.__args__()) + True + """ + return (self._members, self._name) + @property def members(self) -> tuple[PsStructType.Member, ...]: return self._members @@ -276,19 +217,6 @@ class PsStructType(PsType): else: return self._name - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsStructType): - return False - - return ( - self._base_equal(other) - and self._name == other._name - and self._members == other._members - ) - - def __hash__(self) -> int: - return hash(("PsStructTupe", self._name, self._members, self._const)) - def __repr__(self) -> str: members = ", ".join(f"{m.dtype} {m.name}" for m in self._members) name = "<anonymous>" if self.anonymous else f"name={self._name}" @@ -386,6 +314,14 @@ class PsVectorType(PsNumericType): self._vector_entries = vector_entries self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsVectorType(PsBoolType(), 8) + >>> t == PsVectorType(*t.__args__()) + True + """ + return (self._scalar_type, self._vector_entries) + @property def scalar_type(self) -> PsScalarType: return self._scalar_type @@ -437,21 +373,6 @@ class PsVectorType(PsNumericType): [element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype ) - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsVectorType): - return False - - return ( - self._base_equal(other) - and self._scalar_type == other._scalar_type - and self._vector_entries == other._vector_entries - ) - - def __hash__(self) -> int: - return hash( - ("PsVectorType", self._scalar_type, self._vector_entries, self._const) - ) - def c_string(self) -> str: raise PsTypeError("Cannot retrieve C type string for generic vector types.") @@ -473,6 +394,14 @@ class PsBoolType(PsScalarType): def __init__(self, const: bool = False): super().__init__(const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsBoolType() + >>> t == PsBoolType(*t.__args__()) + True + """ + return () + @property def width(self) -> int: return 8 @@ -487,7 +416,9 @@ class PsBoolType(PsScalarType): def create_literal(self, value: Any) -> str: if not isinstance(value, self.NUMPY_TYPE): - raise PsTypeError(f"Given value {value} is not of required type {self.NUMPY_TYPE}") + raise PsTypeError( + f"Given value {value} is not of required type {self.NUMPY_TYPE}" + ) if value == np.True_: return "true" @@ -507,15 +438,6 @@ class PsBoolType(PsScalarType): def c_string(self) -> str: return "bool" - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsBoolType): - return False - - return self._base_equal(other) - - def __hash__(self) -> int: - return hash(("PsBoolType", self._const)) - class PsIntegerType(PsScalarType, ABC): """Signed and unsigned integer types. @@ -563,31 +485,20 @@ class PsIntegerType(PsScalarType, ABC): unsigned_suffix = "" if self.signed else "u" # TODO: cast literal to correct type? return str(value) + unsigned_suffix - + def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] if isinstance(value, (int, np.integer)): iinfo = np.iinfo(np_type) # type: ignore if value < iinfo.min or value > iinfo.max: - raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.") + raise PsTypeError( + f"Could not interpret {value} as {self}: Value is out of bounds." + ) return np_type(value) raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsIntegerType): - return False - - return ( - self._base_equal(other) - and self._width == other._width - and self._signed == other._signed - ) - - def __hash__(self) -> int: - return hash(("PsIntegerType", self._width, self._signed, self._const)) - def c_string(self) -> str: prefix = "" if self._signed else "u" return f"{self._const_string()}{prefix}int{self._width}_t" @@ -612,6 +523,14 @@ class PsSignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, True, const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsSignedIntegerType(32) + >>> t == PsSignedIntegerType(*t.__args__()) + True + """ + return (self._width,) + @final class PsUnsignedIntegerType(PsIntegerType): @@ -629,6 +548,14 @@ class PsUnsignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, False, const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsUnsignedIntegerType(32) + >>> t == PsUnsignedIntegerType(*t.__args__()) + True + """ + return (self._width,) + @final class PsIeeeFloatType(PsScalarType): @@ -647,12 +574,20 @@ class PsIeeeFloatType(PsScalarType): def __init__(self, width: int, const: bool = False): if width not in self.SUPPORTED_WIDTHS: raise ValueError( - f"Invalid integer width {width}; must be one of {self.SUPPORTED_WIDTHS}." + f"Invalid floating-point width {width}; must be one of {self.SUPPORTED_WIDTHS}." ) super().__init__(const) self._width = width + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsIeeeFloatType(32) + >>> t == PsIeeeFloatType(*t.__args__()) + True + """ + return (self._width,) + @property def width(self) -> int: return self._width @@ -693,19 +628,13 @@ class PsIeeeFloatType(PsScalarType): if isinstance(value, (int, float, np.floating)): finfo = np.finfo(np_type) # type: ignore if value < finfo.min or value > finfo.max: - raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.") + raise PsTypeError( + f"Could not interpret {value} as {self}: Value is out of bounds." + ) return np_type(value) raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsIeeeFloatType): - return False - return self._base_equal(other) and self._width == other._width - - def __hash__(self) -> int: - return hash(("PsIeeeFloatType", self._width, self._const)) - def c_string(self) -> str: match self._width: case 16: @@ -719,26 +648,3 @@ class PsIeeeFloatType(PsScalarType): def __repr__(self) -> str: return f"PsIeeeFloatType( width={self.width}, const={self.const} )" - - -T = TypeVar("T", bound=PsType) - - -def constify(t: T) -> T: - """Adds the const qualifier to a given type.""" - if not t.const: - t_copy = copy(t) - t_copy._const = True - return t_copy - else: - return t - - -def deconstify(t: T) -> T: - """Removes the const qualifier from a given type.""" - if t.const: - t_copy = copy(t) - t_copy._const = False - return t_copy - else: - return t diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 0afa5b9e8da6dc18bb47fb2fdf49933b3f7d28d9..60d0d6e7424bdfea730cafe18995afdb7dc253df 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -28,8 +28,8 @@ from pystencils.backend.ast.expressions import ( PsLt, ) from pystencils.backend.constants import PsConstant -from pystencils.types import constify -from pystencils.types.quick import Fp, Bool, create_type, create_numeric_type +from pystencils.types import constify, create_type, create_numeric_type +from pystencils.types.quick import Fp, Bool from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 204ee24cfec0b9219098b0137ed80017b70b8e27..39f89e6fe6ef7ff77fdf5534eaebb3510f9caf4b 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -1,5 +1,6 @@ import pytest import numpy as np +import pickle from pystencils.backend.exceptions import PsInternalCompilerError from pystencils.types import * @@ -19,12 +20,19 @@ def test_widths(Type): def test_parsing_positive(): - assert create_type("const uint32_t * restrict") == Ptr( + assert create_type("const uint32_t * restrict") is Ptr( UInt(32, const=True), restrict=True ) - assert create_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) - assert create_type("uint16 * const") == Ptr(UInt(16), const=True) - assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) + assert create_type("float * * const") is Ptr( + Ptr(Fp(32), restrict=False), const=True, restrict=False + ) + assert create_type("float * * restrict const") is Ptr( + Ptr(Fp(32), restrict=False), const=True, restrict=True + ) + assert create_type("uint16 * const") is Ptr(UInt(16), const=True, restrict=False) + assert create_type("uint64 const * const") is Ptr( + UInt(64, const=True), const=True, restrict=False + ) def test_parsing_negative(): @@ -44,14 +52,14 @@ def test_parsing_negative(): def test_numpy(): import numpy as np - assert create_type(np.single) == create_type(np.float32) == PsIeeeFloatType(32) + assert create_type(np.single) is create_type(np.float32) is PsIeeeFloatType(32) assert ( create_type(float) - == create_type(np.double) - == create_type(np.float64) - == PsIeeeFloatType(64) + is create_type(np.double) + is create_type(np.float64) + is PsIeeeFloatType(64) ) - assert create_type(int) == create_type(np.int64) == PsSignedIntegerType(64) + assert create_type(int) is create_type(np.int64) is PsSignedIntegerType(64) @pytest.mark.parametrize( @@ -101,10 +109,21 @@ def test_numpy_translation(numpy_type): def test_constify(): t = PsCustomType("std::shared_ptr< Custom >") - assert deconstify(t) == t - assert deconstify(constify(t)) == t + assert deconstify(t) is t + assert deconstify(constify(t)) is t + s = PsCustomType("Field", const=True) - assert constify(s) == s + assert constify(s) is s + + i32 = create_type(np.int32) + i32_2 = PsSignedIntegerType(32) + + assert i32 is i32_2 + assert constify(i32) is constify(i32_2) + + i32_const = PsSignedIntegerType(32, const=True) + assert i32_const is not i32 + assert i32_const is constify(i32) def test_struct_types(): @@ -119,3 +138,28 @@ def test_struct_types(): assert str(t) == "<anonymous>" with pytest.raises(PsTypeError): t.c_string() + + +def test_pickle(): + types = [ + Bool(const=True), + Bool(const=False), + Custom("std::vector< uint_t >", const=False), + Ptr(Fp(32, const=False), restrict=True, const=True), + SInt(32, const=True), + SInt(16, const=False), + UInt(8, const=False), + UInt(width=16, const=False), + Int(width=32, const=False), + Fp(width=16, const=True), + PsStructType([("x", UInt(32)), ("y", UInt(32)), ("val", Fp(64))], "myStruct"), + PsStructType([("data", Fp(32))], "None"), + PsArrayType(Fp(16, const=True), 42), + PsArrayType(PsVectorType(Fp(32), 8, const=False), 42) + ] + + dumped = pickle.dumps(types) + restored = pickle.loads(dumped) + + for t1, t2 in zip(types, restored): + assert t1 == t2