From bddd3a37e8b55bc62080da4fcea10bd982356cc2 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 19 Feb 2024 16:07:35 +0100 Subject: [PATCH] refactor field and array handling in context --- .../backend/kernelcreation/context.py | 132 ++++++++++++------ tests/nbackend/kernelcreation/test_options.py | 2 +- 2 files changed, 87 insertions(+), 47 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 0982e8455..68fdfd9c5 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,5 +1,7 @@ from __future__ import annotations +from types import EllipsisType + from ...field import Field, FieldType from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType from ..arrays import PsLinearizedArray @@ -43,15 +45,18 @@ class KernelCreationContext: or full iteration space. """ - def __init__(self, - default_dtype: PsNumericType = PbDefaults.numeric_dtype, - index_dtype: PsIntegerType = PbDefaults.index_dtype): + def __init__( + self, + default_dtype: PsNumericType = PbDefaults.numeric_dtype, + index_dtype: PsIntegerType = PbDefaults.index_dtype, + ): self._default_dtype = default_dtype self._index_dtype = index_dtype - self._arrays: dict[Field, PsLinearizedArray] = dict() self._constraints: list[PsKernelConstraint] = [] + self._field_arrays: dict[Field, PsLinearizedArray] = dict() self._fields_collection = FieldsInKernel() + self._ispace: IterationSpace | None = None @property @@ -76,7 +81,22 @@ class KernelCreationContext: return self._fields_collection def add_field(self, field: Field): - """Add the given field to the context's fields collection""" + """Add the given field to the context's fields collection. + + This method adds the passed ``field`` to the context's field collection, which is + accesible through the `fields` member, and creates an array representation of the field, + which is retrievable through `get_array`. + Before adding the field to the collection, various sanity and constraint checks are applied. + """ + + if field in self._field_arrays: + # Field was already added + return + + arr_shape: list[EllipsisType | int] | None = None + arr_strides: list[EllipsisType | int] | None = None + + # Check field constraints and add to collection match field.field_type: case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX: self._fields_collection.domain_fields.add(field) @@ -87,6 +107,23 @@ class KernelCreationContext: f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " "Buffer fields must be one-dimensional." ) + + if field.index_dimensions > 1: + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields can have at most one index dimension." + ) + + num_entries = field.index_shape[0] if field.index_shape else 1 + if not isinstance(num_entries, int): + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields cannot have variable index shape." + ) + + arr_shape = [..., num_entries] + arr_strides = [num_entries, 1] + self._fields_collection.buffer_fields.add(field) case FieldType.INDEXED: @@ -103,48 +140,51 @@ class KernelCreationContext: case _: assert False, "unreachable code" - def get_array(self, field: Field) -> PsLinearizedArray: - if field not in self._arrays: - if field.field_type == FieldType.BUFFER: - # Buffers are always contiguous - assert field.spatial_dimensions == 1 - assert field.index_dimensions <= 1 - num_entries = field.index_shape[0] if field.index_shape else 1 + # For non-buffer fields, determine shape and strides - arr_shape = [..., num_entries] - arr_strides = [num_entries, 1] - else: - arr_shape = [ - ( - Ellipsis if isinstance(s, TypedSymbol) else s - ) # TODO: Field should also use ellipsis - for s in field.shape - ] - - arr_strides = [ - ( - Ellipsis if isinstance(s, TypedSymbol) else s - ) # TODO: Field should also use ellipsis - for s in field.strides - ] - - # The frontend doesn't quite agree with itself on how to model - # fields with trivial index dimensions. Sometimes the index_shape is empty, - # sometimes its (1,). This is canonicalized here. - if not field.index_shape: - arr_shape += [1] - arr_strides += [1] - - assert isinstance(field.dtype, (BasicType, StructType)) - element_type = make_type(field.dtype.numpy_dtype) - - arr = PsLinearizedArray( - field.name, element_type, arr_shape, arr_strides, self.index_dtype - ) - - self._arrays[field] = arr - - return self._arrays[field] + if arr_shape is None: + arr_shape = [ + ( + Ellipsis if isinstance(s, TypedSymbol) else s + ) # TODO: Field should also use ellipsis + for s in field.shape + ] + + arr_strides = [ + ( + Ellipsis if isinstance(s, TypedSymbol) else s + ) # TODO: Field should also use ellipsis + for s in field.strides + ] + + # The frontend doesn't quite agree with itself on how to model + # fields with trivial index dimensions. Sometimes the index_shape is empty, + # sometimes its (1,). This is canonicalized here. + if not field.index_shape: + arr_shape += [1] + arr_strides += [1] + + # Add array + assert arr_strides is not None + + assert isinstance(field.dtype, (BasicType, StructType)) + element_type = make_type(field.dtype.numpy_dtype) + + arr = PsLinearizedArray( + field.name, element_type, arr_shape, arr_strides, self.index_dtype + ) + + self._field_arrays[field] = arr + + def get_array(self, field: Field) -> PsLinearizedArray: + """Retrieve the underlying array for a given field. + + If the given field was not previously registered using `add_field`, + this method internally calls `add_field` to check the field for consistency. + """ + if field not in self._field_arrays: + self.add_field(field) + return self._field_arrays[field] # Iteration Space diff --git a/tests/nbackend/kernelcreation/test_options.py b/tests/nbackend/kernelcreation/test_options.py index 726ee8def..8d145da75 100644 --- a/tests/nbackend/kernelcreation/test_options.py +++ b/tests/nbackend/kernelcreation/test_options.py @@ -2,7 +2,7 @@ import pytest from pystencils.field import Field, FieldType from pystencils.backend.types.quick import * -from pystencils.kernelcreation import ( +from pystencils.config import ( CreateKernelConfig, PsOptionsError, ) -- GitLab