diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index ccefa1d7816911a0cf8b04cc5f604181f1ccca65..33d5a25d9f90eacec2cb7c7d2dfb6f69924848c1 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -44,7 +44,7 @@ from abc import ABC from .constants import PsConstant from ..types import ( - PsAbstractType, + PsType, PsPointerType, PsIntegerType, PsUnsignedIntegerType, @@ -74,7 +74,7 @@ class PsLinearizedArray: def __init__( self, name: str, - element_type: PsAbstractType, + element_type: PsType, shape: Sequence[int | EllipsisType], strides: Sequence[int | EllipsisType], index_dtype: PsIntegerType = PsSignedIntegerType(64), @@ -159,7 +159,7 @@ class PsArrayAssocSymbol(PsSymbol, ABC): __match_args__ = ("name", "dtype", "array") - def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): + def __init__(self, name: str, dtype: PsType, array: PsLinearizedArray): super().__init__(name, dtype) self._array = array diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 0db3d98b62ba64db1cb5c1c2a672b6ad60374c2f..4d4cfb457b82b4dffca341e1f961f1d556d99df1 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -7,7 +7,7 @@ from ..constants import PsConstant from ..arrays import PsLinearizedArray, PsArrayBasePointer from ..functions import PsFunction from ...types import ( - PsAbstractType, + PsType, PsScalarType, PsVectorType, PsTypeError, @@ -176,7 +176,7 @@ class PsArrayAccess(PsSubscript): return self._base_ptr.array @property - def dtype(self) -> PsAbstractType: + def dtype(self) -> PsType: """Data type of this expression, i.e. the element type of the underlying array""" return self._base_ptr.array.element_type @@ -347,16 +347,16 @@ class PsAddressOf(PsUnOp): class PsCast(PsUnOp): __match_args__ = ("target_type", "operand") - def __init__(self, target_type: PsAbstractType, operand: PsExpression): + def __init__(self, target_type: PsType, operand: PsExpression): super().__init__(operand) self._target_type = target_type @property - def target_type(self) -> PsAbstractType: + def target_type(self) -> PsType: return self._target_type @target_type.setter - def target_type(self, dtype: PsAbstractType): + def target_type(self, dtype: PsType): self._target_type = dtype def structurally_equal(self, other: PsAstNode) -> bool: diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index 060aab78883b2e3d69bfad1f707fc30cb78a1c98..b7a317adeacec9d65cf551e4eb87bdae8e800c32 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -21,7 +21,7 @@ from ..arrays import ( PsArrayStrideSymbol, ) from ...types import ( - PsAbstractType, + PsType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, @@ -217,7 +217,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._call: str | None = None - def _scalar_extractor(self, dtype: PsAbstractType) -> str: + def _scalar_extractor(self, dtype: PsType) -> str: match dtype: case Fp(32) | Fp(64): return "PyFloat_AsDouble" @@ -231,7 +231,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ f"Don't know how to cast Python objects to {dtype}" ) - def _type_char(self, dtype: PsAbstractType) -> str | None: + def _type_char(self, dtype: PsType) -> str | None: if isinstance( dtype, (PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType) ): diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index ffa400a16dcadb4e91de55fbfb54b627a0dfd57e..67aeb70420aa2a7ecbacdaeda5238ab66eeb4ac9 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -10,7 +10,7 @@ from ...sympyextensions.typed_sympy import TypedSymbol from ..symbols import PsSymbol from ..arrays import PsLinearizedArray -from ...types import PsAbstractType, PsIntegerType, PsNumericType, PsScalarType, PsStructType +from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType from ..constraints import PsKernelParamsConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError @@ -89,7 +89,7 @@ class KernelCreationContext: # Symbols - def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol: + def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: if name not in self._symbols: symb = PsSymbol(name, None) self._symbols[name] = symb diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 2f6b5b33b3fdc6fcdc1f2f253bae1818d0dd3b06..5085dccfb1cd933f47ca4d9a7b4d2e71fbd9b3ba 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TypeVar from .context import KernelCreationContext -from ...types import PsAbstractType, PsNumericType, PsStructType, deconstify +from ...types import PsType, PsNumericType, PsStructType, deconstify from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment from ..ast.expressions import ( PsSymbolExpr, @@ -26,7 +26,7 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: - def __init__(self, target_type: PsAbstractType | None = None): + def __init__(self, target_type: PsType | None = None): self._target_type = deconstify(target_type) if target_type is not None else None self._deferred_constants: list[PsConstantExpr] = [] @@ -40,7 +40,7 @@ class TypeContext: else: constexpr.constant.apply_dtype(self._target_type) - def apply_and_check(self, expr: PsExpression, expr_type: PsAbstractType): + def apply_and_check(self, expr: PsExpression, expr_type: PsType): """ If no target type has been set yet, establishes expr_type as the target type and typifies all deferred expressions. @@ -69,7 +69,7 @@ class TypeContext: ) @property - def target_type(self) -> PsAbstractType | None: + def target_type(self) -> PsType | None: return self._target_type diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py index 7b266d05c41747a99930305640facb63e4532782..3c3d5ab6e828d4270d6c5643fded8c5cb63b39d4 100644 --- a/src/pystencils/backend/symbols.py +++ b/src/pystencils/backend/symbols.py @@ -1,4 +1,4 @@ -from ..types import PsAbstractType, PsTypeError +from ..types import PsType, PsTypeError from .exceptions import PsInternalCompilerError @@ -13,7 +13,7 @@ class PsSymbol: __match_args__ = ("name", "dtype") - def __init__(self, name: str, dtype: PsAbstractType | None = None): + def __init__(self, name: str, dtype: PsType | None = None): self._name = name self._dtype = dtype @@ -22,14 +22,14 @@ class PsSymbol: return self._name @property - def dtype(self) -> PsAbstractType | None: + def dtype(self) -> PsType | None: return self._dtype @dtype.setter - def dtype(self, value: PsAbstractType): + def dtype(self, value: PsType): self._dtype = value - def apply_dtype(self, dtype: PsAbstractType): + def apply_dtype(self, dtype: PsType): """Apply the given data type to this symbol, raising a TypeError if it conflicts with a previously set data type.""" @@ -40,7 +40,7 @@ class PsSymbol: self._dtype = dtype - def get_dtype(self) -> PsAbstractType: + def get_dtype(self) -> PsType: if self._dtype is None: raise PsInternalCompilerError("Symbol had no type assigned yet") return self._dtype diff --git a/src/pystencils/defaults.py b/src/pystencils/defaults.py index 16357fe737b3f4265284c05f458aa4dbb1c490a3..f8e96a3a35b2f928213573c09cbb65a2c51a4dfc 100644 --- a/src/pystencils/defaults.py +++ b/src/pystencils/defaults.py @@ -1,5 +1,5 @@ from typing import TypeVar, Generic, Callable -from .types import PsAbstractType, PsIeeeFloatType, PsSignedIntegerType, PsStructType +from .types import PsType, PsIeeeFloatType, PsSignedIntegerType, PsStructType from pystencils.sympyextensions.typed_sympy import TypedSymbol @@ -7,7 +7,7 @@ SymbolT = TypeVar("SymbolT") class GenericDefaults(Generic[SymbolT]): - def __init__(self, symcreate: Callable[[str, PsAbstractType], SymbolT]): + def __init__(self, symcreate: Callable[[str, PsType], SymbolT]): self.numeric_dtype = PsIeeeFloatType(64) """Default data type for numerical computations""" diff --git a/src/pystencils/field.py b/src/pystencils/field.py index 253c94fc0f4a209fcf59d434f9cb78881e8370ab..f1b7cb376dd8ad004510b2ea4c80f8f968152c96 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -14,7 +14,7 @@ from sympy.core.cache import cacheit from pystencils.alignedarray import aligned_empty from pystencils.spatial_coordinates import x_staggered_vector, x_vector from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string -from pystencils.types import PsAbstractType, PsStructType, create_type +from pystencils.types import PsType, PsStructType, create_type from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) from pystencils.sympyextensions.math import is_integer_sequence @@ -315,7 +315,7 @@ class Field: return self.strides[self.spatial_dimensions:] @property - def dtype(self) -> PsAbstractType: + def dtype(self) -> PsType: return self._dtype @property diff --git a/src/pystencils/runhelper/db.py b/src/pystencils/runhelper/db.py index 53d65b2ea16fb90a72ae1f418802404b45842767..acbe5dfa9236105a39f2fbb9bad76bd5b95ac9df 100644 --- a/src/pystencils/runhelper/db.py +++ b/src/pystencils/runhelper/db.py @@ -14,7 +14,7 @@ from pystencils import CreateKernelConfig, Target, Backend, Field import json import sympy as sp -from pystencils.types import PsAbstractType +from pystencils.types import PsType class PystencilsJsonEncoder(JsonEncoder): @@ -26,7 +26,7 @@ class PystencilsJsonEncoder(JsonEncoder): return float(obj) if isinstance(obj, sp.Integer): return int(obj) - if isinstance(obj, (PsAbstractType, MappingProxyType)): + if isinstance(obj, (PsType, MappingProxyType)): return str(obj) if isinstance(obj, (Target, Backend, sp.Symbol)): return obj.name diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index c437afd92cec0925bf772495038228cac33932a1..82af3d5e46136796db96188c69d4de8d325aa607 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -1,10 +1,10 @@ import sympy as sp -from ..types import PsAbstractType, PsNumericType, PsPointerType, PsBoolType +from ..types import PsType, PsNumericType, PsPointerType, PsBoolType from ..types.quick import create_type -def assumptions_from_dtype(dtype: PsAbstractType): +def assumptions_from_dtype(dtype: PsType): """Derives SymPy assumptions from :class:`PsAbstractType` Args: @@ -56,7 +56,7 @@ class TypedSymbol(sp.Symbol): __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__)) @property - def dtype(self) -> PsAbstractType: + def dtype(self) -> PsType: # mypy: ignore return self._dtype @@ -155,7 +155,7 @@ class FieldPointerSymbol(TypedSymbol): obj = FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name, field_dtype: PsAbstractType, const: bool): + def __new_stage2__(cls, field_name, field_dtype: PsType, const: bool): name = f"_data_{field_name}" dtype = PsPointerType(field_dtype, const=const, restrict=True) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) diff --git a/src/pystencils/types/__init__.py b/src/pystencils/types/__init__.py index f33d8a1f7eb51b1e230c20f55d3d61c36da3de42..7096bf6c24a9989e70fa12c5770a4168ba6cd2c4 100644 --- a/src/pystencils/types/__init__.py +++ b/src/pystencils/types/__init__.py @@ -1,5 +1,5 @@ from .basic_types import ( - PsAbstractType, + PsType, PsCustomType, PsStructType, PsNumericType, @@ -20,7 +20,7 @@ from .quick import create_type, create_numeric_type from .exception import PsTypeError __all__ = [ - "PsAbstractType", + "PsType", "PsCustomType", "PsStructType", "PsPointerType", diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index 749a4d34ccc30ad66049ae2d52d7ff896ab5f626..eb373bf3de878379054f4dd9cdfba4f5630516a4 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -9,7 +9,7 @@ import numpy as np from .exception import PsTypeError -class PsAbstractType(ABC): +class PsType(ABC): """Base class for all pystencils types. **Implementation Notes** @@ -58,7 +58,7 @@ class PsAbstractType(ABC): # Internal virtual operations # ------------------------------------------------------------------------------------------- - def _base_equal(self, other: PsAbstractType) -> bool: + def _base_equal(self, other: PsType) -> bool: return type(self) is type(other) and self._const == other._const def _const_string(self) -> str: @@ -84,7 +84,7 @@ class PsAbstractType(ABC): pass -class PsCustomType(PsAbstractType): +class PsCustomType(PsType): """Class to model custom types by their names.""" __match_args__ = ("name",) @@ -113,20 +113,20 @@ class PsCustomType(PsAbstractType): @final -class PsPointerType(PsAbstractType): +class PsPointerType(PsType): """Class to model C pointer types.""" __match_args__ = ("base_type",) def __init__( - self, base_type: PsAbstractType, const: bool = False, restrict: bool = True + self, base_type: PsType, const: bool = False, restrict: bool = True ): super().__init__(const) self._base_type = base_type self._restrict = restrict @property - def base_type(self) -> PsAbstractType: + def base_type(self) -> PsType: return self._base_type @property @@ -150,7 +150,7 @@ class PsPointerType(PsAbstractType): return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" -class PsStructType(PsAbstractType): +class PsStructType(PsType): """Class to model structured data types. A struct type is defined by its sequence of members. @@ -161,11 +161,11 @@ class PsStructType(PsAbstractType): @dataclass(frozen=True) class Member: name: str - dtype: PsAbstractType + dtype: PsType def __init__( self, - members: Sequence[PsStructType.Member | tuple[str, PsAbstractType]], + members: Sequence[PsStructType.Member | tuple[str, PsType]], name: str | None = None, const: bool = False, ): @@ -253,7 +253,7 @@ class PsStructType(PsAbstractType): return f"PsStructType( [{members}], {name}, const={self.const} )" -class PsNumericType(PsAbstractType, ABC): +class PsNumericType(PsType, ABC): """Class to model numeric types, which are all types that may occur at the top-level inside arithmetic-logical expressions. @@ -680,7 +680,7 @@ class PsIeeeFloatType(PsScalarType): return f"PsIeeeFloatType( width={self.width}, const={self.const} )" -T = TypeVar("T", bound=PsAbstractType) +T = TypeVar("T", bound=PsType) def constify(t: T) -> T: diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index be9600c715bb5bf5ec721a9839afdbca9283a5f2..8330d1a683af2b8527e90460306a60666edd89d1 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -1,7 +1,7 @@ import numpy as np from .basic_types import ( - PsAbstractType, + PsType, PsPointerType, PsStructType, PsUnsignedIntegerType, @@ -10,7 +10,7 @@ from .basic_types import ( ) -def interpret_python_type(t: type) -> PsAbstractType: +def interpret_python_type(t: type) -> PsType: if t is int: return PsSignedIntegerType(64) if t is float: @@ -44,7 +44,7 @@ def interpret_python_type(t: type) -> PsAbstractType: raise ValueError(f"Could not interpret Python data type {t} as a pystencils type.") -def interpret_numpy_dtype(t: np.dtype) -> PsAbstractType: +def interpret_numpy_dtype(t: np.dtype) -> PsType: if t.fields is not None: # it's a struct members = [] @@ -60,7 +60,7 @@ def interpret_numpy_dtype(t: np.dtype) -> PsAbstractType: ) -def parse_type_string(s: str) -> PsAbstractType: +def parse_type_string(s: str) -> PsType: tokens = s.rsplit("*", 1) match tokens: case [base]: # input contained no '*', is no pointer diff --git a/src/pystencils/types/quick.py b/src/pystencils/types/quick.py index 0268e9a791f50581ead5b0c2a3b3e64619d4165c..24b6968c7cb3a2874fa79352f0eb7bbc75c50d9f 100644 --- a/src/pystencils/types/quick.py +++ b/src/pystencils/types/quick.py @@ -9,7 +9,7 @@ from __future__ import annotations import numpy as np from .basic_types import ( - PsAbstractType, + PsType, PsCustomType, PsNumericType, PsScalarType, @@ -21,10 +21,10 @@ from .basic_types import ( PsIeeeFloatType, ) -UserTypeSpec = str | type | np.dtype | PsAbstractType +UserTypeSpec = str | type | np.dtype | PsType -def create_type(type_spec: UserTypeSpec) -> PsAbstractType: +def create_type(type_spec: UserTypeSpec) -> PsType: """Create a pystencils type object from a variety of specifications. Possible arguments are: @@ -43,7 +43,7 @@ def create_type(type_spec: UserTypeSpec) -> PsAbstractType: from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype - if isinstance(type_spec, PsAbstractType): + if isinstance(type_spec, PsType): return type_spec if isinstance(type_spec, str): return parse_type_string(type_spec)