Skip to content
Snippets Groups Projects
Commit ee64e7e1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

array associated symbols cleanup

parent 0b14e421
Branches
Tags
No related merge requests found
Pipeline #63215 failed with stages
in 2 minutes and 49 seconds
......@@ -5,7 +5,7 @@ An array has a fixed name, dimensionality, and element type, as well as a number
variables.
The associated variables are the *shape* and *strides* of the array, modelled by the
`PsArrayShapeVar` and `PsArrayStrideVar` classes. They have integer type and are used to
`PsArrayShapeSymbol` and `PsArrayStrideSymbol` classes. They have integer type and are used to
reason about the array's memory layout.
......@@ -68,7 +68,7 @@ class PsLinearizedArray:
For constant entries, their value must be given as an integer.
For variable shape entries and strides, the Ellipsis `...` must be passed instead.
Internally, the passed ``index_dtype`` will be used to create typed constants (`PsTypedConstant`)
and variables (`PsArrayShapeVar` and `PsArrayStrideVar`) from the passed values.
and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values.
"""
def __init__(
......@@ -86,18 +86,18 @@ class PsLinearizedArray:
if len(shape) != len(strides):
raise ValueError("Shape and stride tuples must have the same length")
self._shape: tuple[PsArrayShapeVar | PsConstant, ...] = tuple(
self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple(
(
PsArrayShapeVar(self, i, index_dtype)
PsArrayShapeSymbol(self, i, index_dtype)
if s == Ellipsis
else PsConstant(s, index_dtype)
)
for i, s in enumerate(shape)
)
self._strides: tuple[PsArrayStrideVar | PsConstant, ...] = tuple(
self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple(
(
PsArrayStrideVar(self, i, index_dtype)
PsArrayStrideSymbol(self, i, index_dtype)
if s == Ellipsis
else PsConstant(s, index_dtype)
)
......@@ -117,8 +117,8 @@ class PsLinearizedArray:
return self._base_ptr
@property
def shape(self) -> tuple[PsArrayShapeVar | PsConstant, ...]:
"""The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeVar`"""
def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]:
"""The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeSymbol`"""
return self._shape
@property
......@@ -129,8 +129,8 @@ class PsLinearizedArray:
)
@property
def strides(self) -> tuple[PsArrayStrideVar | PsConstant, ...]:
"""The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideVar`"""
def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]:
"""The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideSymbol`"""
return self._strides
@property
......@@ -144,32 +144,6 @@ class PsLinearizedArray:
def element_type(self):
return self._element_type
def _hashable_contents(self):
"""Contents by which to compare two instances of `PsLinearizedArray`.
Since equality checks on shape and stride variables internally check equality of their associated arrays,
if these variables would occur in here, an infinite recursion would follow.
Hence they are filtered and replaced by the ellipsis.
"""
shape_clean = self.shape_spec
strides_clean = self.strides_spec
return (
self._name,
self._element_type,
self._index_dtype,
shape_clean,
strides_clean,
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLinearizedArray):
return False
return self._hashable_contents() == other._hashable_contents()
def __hash__(self) -> int:
return hash(self._hashable_contents())
def __repr__(self) -> str:
return (
f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
......@@ -182,24 +156,18 @@ class PsArrayAssocSymbol(PsSymbol, ABC):
Instances of this class represent pointers and indexing information bound
to a particular array.
"""
init_arg_names: tuple[str, ...] = ("name", "dtype", "array")
__match_args__ = ("name", "dtype", "array")
def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
super().__init__(name, dtype)
self._array = array
def __getinitargs__(self):
return self.name, self.dtype, self.array
@property
def array(self) -> PsLinearizedArray:
return self._array
class PsArrayBasePointer(PsArrayAssocSymbol):
init_arg_names: tuple[str, ...] = ("name", "array")
__match_args__ = ("name", "array")
def __init__(self, name: str, array: PsLinearizedArray):
......@@ -208,9 +176,6 @@ class PsArrayBasePointer(PsArrayAssocSymbol):
self._array = array
def __getinitargs__(self):
return self.name, self.array
class TypeErasedBasePointer(PsArrayBasePointer):
"""Base pointer for arrays whose element type has been erased.
......@@ -224,14 +189,13 @@ class TypeErasedBasePointer(PsArrayBasePointer):
self._array = array
class PsArrayShapeVar(PsArrayAssocSymbol):
class PsArrayShapeSymbol(PsArrayAssocSymbol):
"""Variable that represents an array's shape in one coordinate.
Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.shape`.
"""
init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
__match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
......@@ -243,18 +207,13 @@ class PsArrayShapeVar(PsArrayAssocSymbol):
def coordinate(self) -> int:
return self._coordinate
def __getinitargs__(self):
return self.array, self.coordinate, self.dtype
class PsArrayStrideVar(PsArrayAssocSymbol):
class PsArrayStrideSymbol(PsArrayAssocSymbol):
"""Variable that represents an array's stride in one coordinate.
Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.strides`.
"""
init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
__match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
......@@ -265,6 +224,3 @@ class PsArrayStrideVar(PsArrayAssocSymbol):
@property
def coordinate(self) -> int:
return self._coordinate
def __getinitargs__(self):
return self.array, self.coordinate, self.dtype
......@@ -17,8 +17,8 @@ from ..arrays import (
PsLinearizedArray,
PsArrayAssocSymbol,
PsArrayBasePointer,
PsArrayShapeVar,
PsArrayStrideVar,
PsArrayShapeSymbol,
PsArrayStrideSymbol,
)
from ..types import (
PsAbstractType,
......@@ -290,12 +290,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
match variable:
case PsArrayBasePointer():
code = f"{variable.dtype} {variable.name} = ({variable.dtype}) {buffer}.buf;"
case PsArrayShapeVar():
case PsArrayShapeSymbol():
coord = variable.coordinate
code = (
f"{variable.dtype} {variable.name} = {buffer}.shape[{coord}];"
)
case PsArrayStrideVar():
case PsArrayStrideSymbol():
coord = variable.coordinate
code = (
f"{variable.dtype} {variable.name} = "
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment