From 3e0c00c4aed4119cf28878a5e0be4fe9dbda4341 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 26 Mar 2025 07:36:38 +0100 Subject: [PATCH] Fixes to postprocessing: Remove unused code, test vector extraction, unify treatment of scalar fields --- src/pystencilssfg/ir/postprocessing.py | 62 +++++++------------ src/pystencilssfg/lang/extractions.py | 8 ++- tests/generator_scripts/index.yaml | 1 + .../source/VectorExtraction.harness.cpp | 30 +++++++++ .../source/VectorExtraction.py | 21 +++++++ tests/ir/test_postprocessing.py | 51 ++++++++++++++- 6 files changed, 128 insertions(+), 45 deletions(-) create mode 100644 tests/generator_scripts/source/VectorExtraction.harness.cpp create mode 100644 tests/generator_scripts/source/VectorExtraction.py diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 8966933..0626e2e 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -27,38 +27,6 @@ from ..lang import ( ) -class FlattenSequences: - """Flattens any nested sequences occuring in a kernel call tree.""" - - def __call__(self, node: SfgCallTreeNode) -> None: - self.visit(node) - - def visit(self, node: SfgCallTreeNode): - match node: - case SfgSequence(): - self.flatten(node) - case _: - for c in node.children: - self.visit(c) - - def flatten(self, sequence: SfgSequence) -> None: - children_flattened: list[SfgCallTreeNode] = [] - - def flatten(seq: SfgSequence): - for c in seq.children: - if isinstance(c, SfgSequence): - flatten(c) - else: - children_flattened.append(c) - - flatten(sequence) - - for c in children_flattened: - self.visit(c) - - sequence.children = children_flattened - - class PostProcessingContext: def __init__(self) -> None: self._live_variables: dict[str, SfgVar] = dict() @@ -129,9 +97,6 @@ class PostProcessingResult: class CallTreePostProcessing: - def __init__(self): - self._flattener = FlattenSequences() - def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: live_vars = self.get_live_variables(ast) return PostProcessingResult(live_vars) @@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode): class SfgDeferredFieldMapping(SfgDeferredNode): """Deferred mapping of a pystencils field to a field data structure.""" + # NOTE ON Scalar Fields + # + # pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`) + # scalar fields. In order to handle both equivalently, + # we ignore the trivial explicit scalar dimension in field extraction. + # This makes sure that explicit D-dimensional scalar fields + # can be mapped onto D-dimensional data structures, and do not require that + # D+1st dimension. + def __init__( self, psfield: Field, @@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: # Find field pointer ptr: SfgKernelParamVar | None = None - shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape) - strides: list[SfgKernelParamVar | str | None] = [None] * len( - self._field.strides - ) + rank: int + + if self._field.index_shape == (1,): + # explicit scalar field -> ignore index dimensions + rank = self._field.spatial_dimensions + else: + rank = len(self._field.shape) + + shape: list[SfgKernelParamVar | str | None] = [None] * rank + strides: list[SfgKernelParamVar | str | None] = [None] * rank for param in ppc.live_variables: if isinstance(param, SfgKernelParamVar): @@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): strides[coord] = param # type: ignore # Find constant or otherwise determined sizes - for coord, s in enumerate(self._field.shape): + for coord, s in enumerate(self._field.shape[:rank]): if shape[coord] is None: shape[coord] = str(s) # Find constant or otherwise determined strides - for coord, s in enumerate(self._field.strides): + for coord, s in enumerate(self._field.strides[:rank]): if strides[coord] is None: strides[coord] = str(s) diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py index e920fcb..39f8462 100644 --- a/src/pystencilssfg/lang/extractions.py +++ b/src/pystencilssfg/lang/extractions.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Protocol +from typing import Protocol, runtime_checkable from abc import abstractmethod from .expressions import AugExpr +@runtime_checkable class SupportsFieldExtraction(Protocol): """Protocol for field pointer and indexing extraction. @@ -13,7 +14,7 @@ class SupportsFieldExtraction(Protocol): They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`. """ -# how-to-guide begin + # how-to-guide begin @abstractmethod def _extract_ptr(self) -> AugExpr: """Extract the field base pointer. @@ -47,9 +48,12 @@ class SupportsFieldExtraction(Protocol): :meta public: """ + + # how-to-guide end +@runtime_checkable class SupportsVectorExtraction(Protocol): """Protocol for component extraction from a vector. diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index c87977f..79f06f5 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -84,6 +84,7 @@ NestedNamespaces: ScaleKernel: JacobiMdspan: StlContainers1D: +VectorExtraction: # std::mdspan diff --git a/tests/generator_scripts/source/VectorExtraction.harness.cpp b/tests/generator_scripts/source/VectorExtraction.harness.cpp new file mode 100644 index 0000000..55c4f05 --- /dev/null +++ b/tests/generator_scripts/source/VectorExtraction.harness.cpp @@ -0,0 +1,30 @@ +#include "VectorExtraction.hpp" +#include <experimental/mdspan> +#include <memory> +#include <vector> + +#undef NDEBUG +#include <cassert> + +namespace stdex = std::experimental; + +using extents_t = stdex::extents<std::int64_t, std::dynamic_extent, std::dynamic_extent, 3>; +using vector_field_t = stdex::mdspan<double, extents_t, stdex::layout_right>; +constexpr size_t N{41}; + +int main(void) +{ + auto u_data = std::make_unique<double[]>(N * N * 3); + vector_field_t u_field{u_data.get(), extents_t{N, N}}; + std::vector<double> v{3.1, 3.2, 3.4}; + + gen::invoke(u_field, v); + + for (size_t j = 0; j < N; ++j) + for (size_t i = 0; i < N; ++i) + { + assert(u_field(j, i, 0) == v[0]); + assert(u_field(j, i, 1) == v[1]); + assert(u_field(j, i, 2) == v[2]); + } +} \ No newline at end of file diff --git a/tests/generator_scripts/source/VectorExtraction.py b/tests/generator_scripts/source/VectorExtraction.py new file mode 100644 index 0000000..dc60eca --- /dev/null +++ b/tests/generator_scripts/source/VectorExtraction.py @@ -0,0 +1,21 @@ +from pystencilssfg import SourceFileGenerator +from pystencilssfg.lang.cpp import std +import pystencils as ps +import sympy as sp + +std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") + +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + + u_field = ps.fields("u(3): double[2D]", layout="c") + u = sp.symbols("u_:3") + + asms = [ps.Assignment(u_field(i), u[i]) for i in range(3)] + ker = sfg.kernels.create(asms) + + sfg.function("invoke")( + sfg.map_field(u_field, std.mdspan.from_field(u_field, layout_policy="layout_right")), + sfg.map_vector(u, std.vector("double", const=True, ref=True).var("vel")), + sfg.call(ker) + ) diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 5a9150b..1b057bc 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -1,10 +1,19 @@ import sympy as sp -from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type +from pystencils import ( + fields, + kernel, + TypedSymbol, + Field, + FieldType, + create_type, + Assignment, +) from pystencils.types import PsCustomType from pystencilssfg.composer import make_sequence from pystencilssfg.lang import AugExpr, SupportsFieldExtraction +from pystencilssfg.lang.cpp import std from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing @@ -100,7 +109,9 @@ def test_field_extraction(sfg): khandle = sfg.kernels.create(set_constant) extraction = DemoFieldExtraction("f") - call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) + call_tree = make_sequence( + sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle) + ) pp = CallTreePostProcessing() free_vars = pp.get_live_variables(call_tree) @@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg): for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True): assert isinstance(stmt, SfgStatements) assert stmt.code_string == line + + +def test_scalar_fields(sfg): + sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,)) + sc_impl = Field.create_generic("f", 1, "double", index_shape=()) + + asm_expl = Assignment(sc_expl.center(0), 3) + asm_impl = Assignment(sc_impl.center(), 3) + + k_expl = sfg.kernels.create(asm_expl, "expl") + k_impl = sfg.kernels.create(asm_impl, "impl") + + tree_expl = make_sequence( + sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl) + ) + + tree_impl = make_sequence( + sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl) + ) + + pp = CallTreePostProcessing() + _ = pp.get_live_variables(tree_expl) + _ = pp.get_live_variables(tree_impl) + + extraction_expl = tree_expl.children[0] + assert isinstance(extraction_expl, SfgSequence) + + extraction_impl = tree_impl.children[0] + assert isinstance(extraction_impl, SfgSequence) + + for node1, node2 in zip( + extraction_expl.children, extraction_impl.children, strict=True + ): + assert isinstance(node1, SfgStatements) + assert isinstance(node2, SfgStatements) + assert node1.code_string == node2.code_string -- GitLab