diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index 8f36bc76d34c594ff5f28187ead79026bf12f784..586da3799f3119f958edbd0d36633aa675cb4c88 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -5,7 +5,7 @@ An array has a fixed name, dimensionality, and element type, as well as a number variables. The associated variables are the *shape* and *strides* of the array, modelled by the -`PsArrayShapeVar` and `PsArrayStrideVar` classes. They have integer type and are used to +`PsArrayShapeSymbol` and `PsArrayStrideSymbol` classes. They have integer type and are used to reason about the array's memory layout. @@ -68,7 +68,7 @@ class PsLinearizedArray: For constant entries, their value must be given as an integer. For variable shape entries and strides, the Ellipsis `...` must be passed instead. Internally, the passed ``index_dtype`` will be used to create typed constants (`PsTypedConstant`) - and variables (`PsArrayShapeVar` and `PsArrayStrideVar`) from the passed values. + and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values. """ def __init__( @@ -86,18 +86,18 @@ class PsLinearizedArray: if len(shape) != len(strides): raise ValueError("Shape and stride tuples must have the same length") - self._shape: tuple[PsArrayShapeVar | PsConstant, ...] = tuple( + self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple( ( - PsArrayShapeVar(self, i, index_dtype) + PsArrayShapeSymbol(self, i, index_dtype) if s == Ellipsis else PsConstant(s, index_dtype) ) for i, s in enumerate(shape) ) - self._strides: tuple[PsArrayStrideVar | PsConstant, ...] = tuple( + self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple( ( - PsArrayStrideVar(self, i, index_dtype) + PsArrayStrideSymbol(self, i, index_dtype) if s == Ellipsis else PsConstant(s, index_dtype) ) @@ -117,8 +117,8 @@ class PsLinearizedArray: return self._base_ptr @property - def shape(self) -> tuple[PsArrayShapeVar | PsConstant, ...]: - """The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeVar`""" + def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]: + """The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeSymbol`""" return self._shape @property @@ -129,8 +129,8 @@ class PsLinearizedArray: ) @property - def strides(self) -> tuple[PsArrayStrideVar | PsConstant, ...]: - """The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideVar`""" + def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]: + """The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideSymbol`""" return self._strides @property @@ -144,32 +144,6 @@ class PsLinearizedArray: def element_type(self): return self._element_type - def _hashable_contents(self): - """Contents by which to compare two instances of `PsLinearizedArray`. - - Since equality checks on shape and stride variables internally check equality of their associated arrays, - if these variables would occur in here, an infinite recursion would follow. - Hence they are filtered and replaced by the ellipsis. - """ - shape_clean = self.shape_spec - strides_clean = self.strides_spec - return ( - self._name, - self._element_type, - self._index_dtype, - shape_clean, - strides_clean, - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsLinearizedArray): - return False - - return self._hashable_contents() == other._hashable_contents() - - def __hash__(self) -> int: - return hash(self._hashable_contents()) - def __repr__(self) -> str: return ( f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" @@ -182,24 +156,18 @@ class PsArrayAssocSymbol(PsSymbol, ABC): Instances of this class represent pointers and indexing information bound to a particular array. """ - - init_arg_names: tuple[str, ...] = ("name", "dtype", "array") __match_args__ = ("name", "dtype", "array") def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): super().__init__(name, dtype) self._array = array - def __getinitargs__(self): - return self.name, self.dtype, self.array - @property def array(self) -> PsLinearizedArray: return self._array class PsArrayBasePointer(PsArrayAssocSymbol): - init_arg_names: tuple[str, ...] = ("name", "array") __match_args__ = ("name", "array") def __init__(self, name: str, array: PsLinearizedArray): @@ -208,9 +176,6 @@ class PsArrayBasePointer(PsArrayAssocSymbol): self._array = array - def __getinitargs__(self): - return self.name, self.array - class TypeErasedBasePointer(PsArrayBasePointer): """Base pointer for arrays whose element type has been erased. @@ -224,14 +189,13 @@ class TypeErasedBasePointer(PsArrayBasePointer): self._array = array -class PsArrayShapeVar(PsArrayAssocSymbol): +class PsArrayShapeSymbol(PsArrayAssocSymbol): """Variable that represents an array's shape in one coordinate. Do not instantiate this class yourself, but only use its instances as provided by `PsLinearizedArray.shape`. """ - init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype") def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): @@ -243,18 +207,13 @@ class PsArrayShapeVar(PsArrayAssocSymbol): def coordinate(self) -> int: return self._coordinate - def __getinitargs__(self): - return self.array, self.coordinate, self.dtype - -class PsArrayStrideVar(PsArrayAssocSymbol): +class PsArrayStrideSymbol(PsArrayAssocSymbol): """Variable that represents an array's stride in one coordinate. Do not instantiate this class yourself, but only use its instances as provided by `PsLinearizedArray.strides`. """ - - init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype") def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): @@ -265,6 +224,3 @@ class PsArrayStrideVar(PsArrayAssocSymbol): @property def coordinate(self) -> int: return self._coordinate - - def __getinitargs__(self): - return self.array, self.coordinate, self.dtype diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index 6004d1d175b7a819e19b8c0bd02ce66a959184d2..ce0ab049b81d7b2a1d367d005b8a51997a2edc1e 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -17,8 +17,8 @@ from ..arrays import ( PsLinearizedArray, PsArrayAssocSymbol, PsArrayBasePointer, - PsArrayShapeVar, - PsArrayStrideVar, + PsArrayShapeSymbol, + PsArrayStrideSymbol, ) from ..types import ( PsAbstractType, @@ -290,12 +290,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ match variable: case PsArrayBasePointer(): code = f"{variable.dtype} {variable.name} = ({variable.dtype}) {buffer}.buf;" - case PsArrayShapeVar(): + case PsArrayShapeSymbol(): coord = variable.coordinate code = ( f"{variable.dtype} {variable.name} = {buffer}.shape[{coord}];" ) - case PsArrayStrideVar(): + case PsArrayStrideSymbol(): coord = variable.coordinate code = ( f"{variable.dtype} {variable.name} = "