diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 761f2eba38c7fd280e13d701129af7c69a73163e..24d52bec8a2ce4bdd58cefbc384551282267309d 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import TypeAlias, Union, Any +from functools import reduce +from typing import TypeAlias, Union, Any, Tuple import pymbolic.primitives as pb -from ..typing import AbstractType, BasicType +from ..typing import AbstractType, BasicType, PointerType class PsTypedVariable(pb.Variable): @@ -17,26 +18,102 @@ class PsTypedVariable(pb.Variable): return self._dtype +class PsArray: + def __init__( + self, + name: str, + length: pb.Expression, + element_type: BasicType, # todo Frederik: is BasicType correct? + ): + self._name = name + self._length = length + self._element_type = element_type + + @property + def name(self): + return self._name + + @property + def length(self): + return self._length + + @property + def element_type(self): + return self._element_type + + +class PsLinearizedArray(PsArray): + """N-dimensional contiguous array""" + + def __init__( + self, + name: str, + shape: Tuple[pb.Expression, ...], + strides: Tuple[pb.Expression], + element_type: BasicType, + ): + length = reduce(lambda x, y: x * y, shape, 1) + super().__init__(name, length, element_type) + + self._shape = shape + self._strides = strides + + @property + def shape(self): + return self._shape + + @property + def strides(self): + return self._strides + + class PsArrayBasePointer(PsTypedVariable): - def __init__(self, name: str, base_type: AbstractType): - super(PsArrayBasePointer, self).__init__(name, base_type) + def __init__(self, name: str, array: PsArray): + dtype = PointerType(array.element_type) + super().__init__(name, dtype) + + self._array = array + + @property + def array(self): + return self._array class PsArrayAccess(pb.Subscript): def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression): super(PsArrayAccess, self).__init__(base_ptr, index) + self._base_ptr = base_ptr + self._index = index + + @property + def base_ptr(self): + return self._base_ptr + + @property + def index(self): + return self._index + + @property + def array(self) -> PsArray: + return self._base_ptr.array + + @property + def dtype(self) -> AbstractType: + """Data type of this expression, i.e. the element type of the underlying array""" + return self._base_ptr.array.element_type PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] class PsTypedConstant: - @staticmethod def _cast(value, target_dtype: AbstractType): if isinstance(value, PsTypedConstant): if value._dtype != target_dtype: - raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}") + raise ValueError( + f"Incompatible types: {value._dtype} and {target_dtype}" + ) return value # TODO check legality