Skip to content
Snippets Groups Projects
Commit 45a51fe6 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

implement buffer index calculation; untested

parent 6a205065
No related merge requests found
Pipeline #61687 failed with stages
in 6 minutes and 49 seconds
......@@ -39,6 +39,7 @@ all occurences of the shape and stride variables with their constant value::
from __future__ import annotations
from sys import intern
from typing import Sequence
from types import EllipsisType
from abc import ABC
......@@ -71,8 +72,8 @@ class PsLinearizedArray:
self,
name: str,
element_type: PsAbstractType,
shape: tuple[int | EllipsisType, ...],
strides: tuple[int | EllipsisType, ...],
shape: Sequence[int | EllipsisType],
strides: Sequence[int | EllipsisType],
index_dtype: PsIntegerType = PsSignedIntegerType(64),
):
self._name = name
......
......@@ -103,30 +103,39 @@ class KernelCreationContext:
def get_array(self, field: Field) -> PsLinearizedArray:
if field not in self._arrays:
arr_shape = tuple(
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.shape
)
arr_strides = tuple(
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
)
if field.field_type == FieldType.BUFFER:
# Buffers are always contiguous
assert field.spatial_dimensions == 1
assert field.index_dimensions <= 1
num_entries = field.index_shape[0] if field.index_shape else 1
arr_shape = [..., num_entries]
arr_strides = [num_entries, 1]
else:
arr_shape = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.shape
]
arr_strides = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
]
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += [1]
arr_strides += [1]
assert isinstance(field.dtype, (BasicType, StructType))
element_type = make_type(field.dtype.numpy_dtype)
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += (1,)
arr_strides += (1,)
arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype
)
......
......@@ -3,7 +3,6 @@ from typing import overload, cast
import sympy as sp
import pymbolic.primitives as pb
from pymbolic.interop.sympy import SympyToPymbolicMapper
from itertools import chain
from ...assignment import Assignment
from ...simp import AssignmentCollection
......@@ -101,11 +100,16 @@ class FreezeExpressions(SympyToPymbolicMapper):
)
]
case FieldType.INDEXED:
# flake8: noqa
sparse_ispace = self._ctx.get_sparse_iteration_space()
# Add sparse iteration counter to offset
assert len(offsets) == 1 # must have been checked by the context
offsets = [offsets[0] + sparse_ispace.sparse_counter]
case FieldType.BUFFER:
# TODO: Test Cases
ispace = self._ctx.get_full_iteration_space()
compressed_ctr = ispace.compressed_counter()
assert len(offsets) == 1
offsets = [compressed_ctr + offsets[0]]
case FieldType.CUSTOM:
raise ValueError("Custom fields support only absolute accesses.")
case unknown:
......@@ -114,7 +118,6 @@ class FreezeExpressions(SympyToPymbolicMapper):
)
# If the array type is a struct, accesses are modelled using strings
# In that case, the index is empty
if isinstance(array.element_type, PsStructType):
if isinstance(access.index, str):
struct_member_name = access.index
......
......@@ -141,15 +141,26 @@ class FullIterationSpace(IterationSpace):
def steps(self):
return (dim.step for dim in self._dimensions)
def num_iteration_items(self, dimension: int | None = None) -> ExprOrConstant:
def actual_iterations(self, dimension: int | None = None) -> ExprOrConstant:
if dimension is None:
return reduce(
mul, (self.num_iteration_items(d) for d in range(len(self.dimensions)))
mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
)
else:
dim = self.dimensions[dimension]
one = PsTypedConstant(1, self._ctx.index_dtype)
return one + (dim.stop - dim.start - one) / dim.step
def compressed_counter(self) -> ExprOrConstant:
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple.
Used primarily for indexing buffers."""
actual_iters = [self.actual_iterations()]
compressed_counters = [(dim.counter - dim.start) / dim.step for dim in self.dimensions]
compressed_idx = compressed_counters[0]
for ctr, iters in zip(compressed_counters[1:], actual_iters[1:]):
compressed_idx = compressed_idx * iters + ctr
return compressed_idx
class SparseIterationSpace(IterationSpace):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment