diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py index 60343c4d9f211e6408ddf04eacf359e6c082007a..8c8ae200113e05160cee79f04a3f181e3a3a6313 100644 --- a/src/pystencils/nbackend/arrays.py +++ b/src/pystencils/nbackend/arrays.py @@ -39,6 +39,7 @@ all occurences of the shape and stride variables with their constant value:: from __future__ import annotations from sys import intern +from typing import Sequence from types import EllipsisType from abc import ABC @@ -71,8 +72,8 @@ class PsLinearizedArray: self, name: str, element_type: PsAbstractType, - shape: tuple[int | EllipsisType, ...], - strides: tuple[int | EllipsisType, ...], + shape: Sequence[int | EllipsisType], + strides: Sequence[int | EllipsisType], index_dtype: PsIntegerType = PsSignedIntegerType(64), ): self._name = name diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py index b46fbeb96a1c32befd6b8df16dc90cc13c050960..d2cddb1421e6c3435e0b0bccd12b55ef98f262df 100644 --- a/src/pystencils/nbackend/kernelcreation/context.py +++ b/src/pystencils/nbackend/kernelcreation/context.py @@ -103,30 +103,39 @@ class KernelCreationContext: def get_array(self, field: Field) -> PsLinearizedArray: if field not in self._arrays: - arr_shape = tuple( - ( - Ellipsis if isinstance(s, TypedSymbol) else s - ) # TODO: Field should also use ellipsis - for s in field.shape - ) - - arr_strides = tuple( - ( - Ellipsis if isinstance(s, TypedSymbol) else s - ) # TODO: Field should also use ellipsis - for s in field.strides - ) + if field.field_type == FieldType.BUFFER: + # Buffers are always contiguous + assert field.spatial_dimensions == 1 + assert field.index_dimensions <= 1 + num_entries = field.index_shape[0] if field.index_shape else 1 + + arr_shape = [..., num_entries] + arr_strides = [num_entries, 1] + else: + arr_shape = [ + ( + Ellipsis if isinstance(s, TypedSymbol) else s + ) # TODO: Field should also use ellipsis + for s in field.shape + ] + + arr_strides = [ + ( + Ellipsis if isinstance(s, TypedSymbol) else s + ) # TODO: Field should also use ellipsis + for s in field.strides + ] + + # The frontend doesn't quite agree with itself on how to model + # fields with trivial index dimensions. Sometimes the index_shape is empty, + # sometimes its (1,). This is canonicalized here. + if not field.index_shape: + arr_shape += [1] + arr_strides += [1] assert isinstance(field.dtype, (BasicType, StructType)) element_type = make_type(field.dtype.numpy_dtype) - # The frontend doesn't quite agree with itself on how to model - # fields with trivial index dimensions. Sometimes the index_shape is empty, - # sometimes its (1,). This is canonicalized here. - if not field.index_shape: - arr_shape += (1,) - arr_strides += (1,) - arr = PsLinearizedArray( field.name, element_type, arr_shape, arr_strides, self.index_dtype ) diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index 38f100e72a18c4eb76418966f08ef6b5f9016fb4..94ecc63315349ca694b1035a11eb6ea92d42431d 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -3,7 +3,6 @@ from typing import overload, cast import sympy as sp import pymbolic.primitives as pb from pymbolic.interop.sympy import SympyToPymbolicMapper -from itertools import chain from ...assignment import Assignment from ...simp import AssignmentCollection @@ -101,11 +100,16 @@ class FreezeExpressions(SympyToPymbolicMapper): ) ] case FieldType.INDEXED: - # flake8: noqa sparse_ispace = self._ctx.get_sparse_iteration_space() # Add sparse iteration counter to offset assert len(offsets) == 1 # must have been checked by the context offsets = [offsets[0] + sparse_ispace.sparse_counter] + case FieldType.BUFFER: + # TODO: Test Cases + ispace = self._ctx.get_full_iteration_space() + compressed_ctr = ispace.compressed_counter() + assert len(offsets) == 1 + offsets = [compressed_ctr + offsets[0]] case FieldType.CUSTOM: raise ValueError("Custom fields support only absolute accesses.") case unknown: @@ -114,7 +118,6 @@ class FreezeExpressions(SympyToPymbolicMapper): ) # If the array type is a struct, accesses are modelled using strings - # In that case, the index is empty if isinstance(array.element_type, PsStructType): if isinstance(access.index, str): struct_member_name = access.index diff --git a/src/pystencils/nbackend/kernelcreation/iteration_space.py b/src/pystencils/nbackend/kernelcreation/iteration_space.py index 43ac685e5c59acc805e54fa22612d056bd804101..a739893c0826c523d4fcf060022236249f2a8d40 100644 --- a/src/pystencils/nbackend/kernelcreation/iteration_space.py +++ b/src/pystencils/nbackend/kernelcreation/iteration_space.py @@ -141,15 +141,26 @@ class FullIterationSpace(IterationSpace): def steps(self): return (dim.step for dim in self._dimensions) - def num_iteration_items(self, dimension: int | None = None) -> ExprOrConstant: + def actual_iterations(self, dimension: int | None = None) -> ExprOrConstant: if dimension is None: return reduce( - mul, (self.num_iteration_items(d) for d in range(len(self.dimensions))) + mul, (self.actual_iterations(d) for d in range(len(self.dimensions))) ) else: dim = self.dimensions[dimension] one = PsTypedConstant(1, self._ctx.index_dtype) return one + (dim.stop - dim.start - one) / dim.step + + def compressed_counter(self) -> ExprOrConstant: + """Expression counting the actual number of items processed at the iteration defined by the counter tuple. + + Used primarily for indexing buffers.""" + actual_iters = [self.actual_iterations()] + compressed_counters = [(dim.counter - dim.start) / dim.step for dim in self.dimensions] + compressed_idx = compressed_counters[0] + for ctr, iters in zip(compressed_counters[1:], actual_iters[1:]): + compressed_idx = compressed_idx * iters + ctr + return compressed_idx class SparseIterationSpace(IterationSpace):