Skip to content
Snippets Groups Projects
Commit cbcd502d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

extended sketch of backend array modelling

parent 85b35137
No related merge requests found
Pipeline #59604 failed with stages
in 2 minutes and 53 seconds
from __future__ import annotations
from typing import TypeAlias, Union, Any
from functools import reduce
from typing import TypeAlias, Union, Any, Tuple
import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType
from ..typing import AbstractType, BasicType, PointerType
class PsTypedVariable(pb.Variable):
......@@ -17,26 +18,102 @@ class PsTypedVariable(pb.Variable):
return self._dtype
class PsArray:
def __init__(
self,
name: str,
length: pb.Expression,
element_type: BasicType, # todo Frederik: is BasicType correct?
):
self._name = name
self._length = length
self._element_type = element_type
@property
def name(self):
return self._name
@property
def length(self):
return self._length
@property
def element_type(self):
return self._element_type
class PsLinearizedArray(PsArray):
"""N-dimensional contiguous array"""
def __init__(
self,
name: str,
shape: Tuple[pb.Expression, ...],
strides: Tuple[pb.Expression],
element_type: BasicType,
):
length = reduce(lambda x, y: x * y, shape, 1)
super().__init__(name, length, element_type)
self._shape = shape
self._strides = strides
@property
def shape(self):
return self._shape
@property
def strides(self):
return self._strides
class PsArrayBasePointer(PsTypedVariable):
def __init__(self, name: str, base_type: AbstractType):
super(PsArrayBasePointer, self).__init__(name, base_type)
def __init__(self, name: str, array: PsArray):
dtype = PointerType(array.element_type)
super().__init__(name, dtype)
self._array = array
@property
def array(self):
return self._array
class PsArrayAccess(pb.Subscript):
def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
super(PsArrayAccess, self).__init__(base_ptr, index)
self._base_ptr = base_ptr
self._index = index
@property
def base_ptr(self):
return self._base_ptr
@property
def index(self):
return self._index
@property
def array(self) -> PsArray:
return self._base_ptr.array
@property
def dtype(self) -> AbstractType:
"""Data type of this expression, i.e. the element type of the underlying array"""
return self._base_ptr.array.element_type
PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
class PsTypedConstant:
@staticmethod
def _cast(value, target_dtype: AbstractType):
if isinstance(value, PsTypedConstant):
if value._dtype != target_dtype:
raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}")
raise ValueError(
f"Incompatible types: {value._dtype} and {target_dtype}"
)
return value
# TODO check legality
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment