Skip to content
Snippets Groups Projects
Commit 45a51fe6 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

implement buffer index calculation; untested

parent 6a205065
No related merge requests found
Pipeline #61687 failed with stages
in 6 minutes and 49 seconds
...@@ -39,6 +39,7 @@ all occurences of the shape and stride variables with their constant value:: ...@@ -39,6 +39,7 @@ all occurences of the shape and stride variables with their constant value::
from __future__ import annotations from __future__ import annotations
from sys import intern from sys import intern
from typing import Sequence
from types import EllipsisType from types import EllipsisType
from abc import ABC from abc import ABC
...@@ -71,8 +72,8 @@ class PsLinearizedArray: ...@@ -71,8 +72,8 @@ class PsLinearizedArray:
self, self,
name: str, name: str,
element_type: PsAbstractType, element_type: PsAbstractType,
shape: tuple[int | EllipsisType, ...], shape: Sequence[int | EllipsisType],
strides: tuple[int | EllipsisType, ...], strides: Sequence[int | EllipsisType],
index_dtype: PsIntegerType = PsSignedIntegerType(64), index_dtype: PsIntegerType = PsSignedIntegerType(64),
): ):
self._name = name self._name = name
......
...@@ -103,30 +103,39 @@ class KernelCreationContext: ...@@ -103,30 +103,39 @@ class KernelCreationContext:
def get_array(self, field: Field) -> PsLinearizedArray: def get_array(self, field: Field) -> PsLinearizedArray:
if field not in self._arrays: if field not in self._arrays:
arr_shape = tuple( if field.field_type == FieldType.BUFFER:
( # Buffers are always contiguous
Ellipsis if isinstance(s, TypedSymbol) else s assert field.spatial_dimensions == 1
) # TODO: Field should also use ellipsis assert field.index_dimensions <= 1
for s in field.shape num_entries = field.index_shape[0] if field.index_shape else 1
)
arr_shape = [..., num_entries]
arr_strides = tuple( arr_strides = [num_entries, 1]
( else:
Ellipsis if isinstance(s, TypedSymbol) else s arr_shape = [
) # TODO: Field should also use ellipsis (
for s in field.strides Ellipsis if isinstance(s, TypedSymbol) else s
) ) # TODO: Field should also use ellipsis
for s in field.shape
]
arr_strides = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
]
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += [1]
arr_strides += [1]
assert isinstance(field.dtype, (BasicType, StructType)) assert isinstance(field.dtype, (BasicType, StructType))
element_type = make_type(field.dtype.numpy_dtype) element_type = make_type(field.dtype.numpy_dtype)
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += (1,)
arr_strides += (1,)
arr = PsLinearizedArray( arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype field.name, element_type, arr_shape, arr_strides, self.index_dtype
) )
......
...@@ -3,7 +3,6 @@ from typing import overload, cast ...@@ -3,7 +3,6 @@ from typing import overload, cast
import sympy as sp import sympy as sp
import pymbolic.primitives as pb import pymbolic.primitives as pb
from pymbolic.interop.sympy import SympyToPymbolicMapper from pymbolic.interop.sympy import SympyToPymbolicMapper
from itertools import chain
from ...assignment import Assignment from ...assignment import Assignment
from ...simp import AssignmentCollection from ...simp import AssignmentCollection
...@@ -101,11 +100,16 @@ class FreezeExpressions(SympyToPymbolicMapper): ...@@ -101,11 +100,16 @@ class FreezeExpressions(SympyToPymbolicMapper):
) )
] ]
case FieldType.INDEXED: case FieldType.INDEXED:
# flake8: noqa
sparse_ispace = self._ctx.get_sparse_iteration_space() sparse_ispace = self._ctx.get_sparse_iteration_space()
# Add sparse iteration counter to offset # Add sparse iteration counter to offset
assert len(offsets) == 1 # must have been checked by the context assert len(offsets) == 1 # must have been checked by the context
offsets = [offsets[0] + sparse_ispace.sparse_counter] offsets = [offsets[0] + sparse_ispace.sparse_counter]
case FieldType.BUFFER:
# TODO: Test Cases
ispace = self._ctx.get_full_iteration_space()
compressed_ctr = ispace.compressed_counter()
assert len(offsets) == 1
offsets = [compressed_ctr + offsets[0]]
case FieldType.CUSTOM: case FieldType.CUSTOM:
raise ValueError("Custom fields support only absolute accesses.") raise ValueError("Custom fields support only absolute accesses.")
case unknown: case unknown:
...@@ -114,7 +118,6 @@ class FreezeExpressions(SympyToPymbolicMapper): ...@@ -114,7 +118,6 @@ class FreezeExpressions(SympyToPymbolicMapper):
) )
# If the array type is a struct, accesses are modelled using strings # If the array type is a struct, accesses are modelled using strings
# In that case, the index is empty
if isinstance(array.element_type, PsStructType): if isinstance(array.element_type, PsStructType):
if isinstance(access.index, str): if isinstance(access.index, str):
struct_member_name = access.index struct_member_name = access.index
......
...@@ -141,15 +141,26 @@ class FullIterationSpace(IterationSpace): ...@@ -141,15 +141,26 @@ class FullIterationSpace(IterationSpace):
def steps(self): def steps(self):
return (dim.step for dim in self._dimensions) return (dim.step for dim in self._dimensions)
def num_iteration_items(self, dimension: int | None = None) -> ExprOrConstant: def actual_iterations(self, dimension: int | None = None) -> ExprOrConstant:
if dimension is None: if dimension is None:
return reduce( return reduce(
mul, (self.num_iteration_items(d) for d in range(len(self.dimensions))) mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
) )
else: else:
dim = self.dimensions[dimension] dim = self.dimensions[dimension]
one = PsTypedConstant(1, self._ctx.index_dtype) one = PsTypedConstant(1, self._ctx.index_dtype)
return one + (dim.stop - dim.start - one) / dim.step return one + (dim.stop - dim.start - one) / dim.step
def compressed_counter(self) -> ExprOrConstant:
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple.
Used primarily for indexing buffers."""
actual_iters = [self.actual_iterations()]
compressed_counters = [(dim.counter - dim.start) / dim.step for dim in self.dimensions]
compressed_idx = compressed_counters[0]
for ctr, iters in zip(compressed_counters[1:], actual_iters[1:]):
compressed_idx = compressed_idx * iters + ctr
return compressed_idx
class SparseIterationSpace(IterationSpace): class SparseIterationSpace(IterationSpace):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment