From 3e0c00c4aed4119cf28878a5e0be4fe9dbda4341 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 26 Mar 2025 07:36:38 +0100
Subject: [PATCH] Fixes to postprocessing: Remove unused code, test vector
 extraction, unify treatment of scalar fields

---
 src/pystencilssfg/ir/postprocessing.py        | 62 +++++++------------
 src/pystencilssfg/lang/extractions.py         |  8 ++-
 tests/generator_scripts/index.yaml            |  1 +
 .../source/VectorExtraction.harness.cpp       | 30 +++++++++
 .../source/VectorExtraction.py                | 21 +++++++
 tests/ir/test_postprocessing.py               | 51 ++++++++++++++-
 6 files changed, 128 insertions(+), 45 deletions(-)
 create mode 100644 tests/generator_scripts/source/VectorExtraction.harness.cpp
 create mode 100644 tests/generator_scripts/source/VectorExtraction.py

diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py
index 8966933..0626e2e 100644
--- a/src/pystencilssfg/ir/postprocessing.py
+++ b/src/pystencilssfg/ir/postprocessing.py
@@ -27,38 +27,6 @@ from ..lang import (
 )
 
 
-class FlattenSequences:
-    """Flattens any nested sequences occuring in a kernel call tree."""
-
-    def __call__(self, node: SfgCallTreeNode) -> None:
-        self.visit(node)
-
-    def visit(self, node: SfgCallTreeNode):
-        match node:
-            case SfgSequence():
-                self.flatten(node)
-            case _:
-                for c in node.children:
-                    self.visit(c)
-
-    def flatten(self, sequence: SfgSequence) -> None:
-        children_flattened: list[SfgCallTreeNode] = []
-
-        def flatten(seq: SfgSequence):
-            for c in seq.children:
-                if isinstance(c, SfgSequence):
-                    flatten(c)
-                else:
-                    children_flattened.append(c)
-
-        flatten(sequence)
-
-        for c in children_flattened:
-            self.visit(c)
-
-        sequence.children = children_flattened
-
-
 class PostProcessingContext:
     def __init__(self) -> None:
         self._live_variables: dict[str, SfgVar] = dict()
@@ -129,9 +97,6 @@ class PostProcessingResult:
 
 
 class CallTreePostProcessing:
-    def __init__(self):
-        self._flattener = FlattenSequences()
-
     def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
         live_vars = self.get_live_variables(ast)
         return PostProcessingResult(live_vars)
@@ -214,6 +179,15 @@ class SfgDeferredParamSetter(SfgDeferredNode):
 class SfgDeferredFieldMapping(SfgDeferredNode):
     """Deferred mapping of a pystencils field to a field data structure."""
 
+    #   NOTE ON Scalar Fields
+    #
+    #   pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
+    #   scalar fields. In order to handle both equivalently,
+    #   we ignore the trivial explicit scalar dimension in field extraction.
+    #   This makes sure that explicit D-dimensional scalar fields
+    #   can be mapped onto D-dimensional data structures, and do not require that
+    #   D+1st dimension.
+
     def __init__(
         self,
         psfield: Field,
@@ -227,10 +201,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
     def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
         #    Find field pointer
         ptr: SfgKernelParamVar | None = None
-        shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape)
-        strides: list[SfgKernelParamVar | str | None] = [None] * len(
-            self._field.strides
-        )
+        rank: int
+
+        if self._field.index_shape == (1,):
+            #   explicit scalar field -> ignore index dimensions
+            rank = self._field.spatial_dimensions
+        else:
+            rank = len(self._field.shape)
+
+        shape: list[SfgKernelParamVar | str | None] = [None] * rank
+        strides: list[SfgKernelParamVar | str | None] = [None] * rank
 
         for param in ppc.live_variables:
             if isinstance(param, SfgKernelParamVar):
@@ -244,12 +224,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
                             strides[coord] = param  # type: ignore
 
         #   Find constant or otherwise determined sizes
-        for coord, s in enumerate(self._field.shape):
+        for coord, s in enumerate(self._field.shape[:rank]):
             if shape[coord] is None:
                 shape[coord] = str(s)
 
         #   Find constant or otherwise determined strides
-        for coord, s in enumerate(self._field.strides):
+        for coord, s in enumerate(self._field.strides[:rank]):
             if strides[coord] is None:
                 strides[coord] = str(s)
 
diff --git a/src/pystencilssfg/lang/extractions.py b/src/pystencilssfg/lang/extractions.py
index e920fcb..39f8462 100644
--- a/src/pystencilssfg/lang/extractions.py
+++ b/src/pystencilssfg/lang/extractions.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
-from typing import Protocol
+from typing import Protocol, runtime_checkable
 from abc import abstractmethod
 
 from .expressions import AugExpr
 
 
+@runtime_checkable
 class SupportsFieldExtraction(Protocol):
     """Protocol for field pointer and indexing extraction.
 
@@ -13,7 +14,7 @@ class SupportsFieldExtraction(Protocol):
     They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`.
     """
 
-#  how-to-guide begin
+    #  how-to-guide begin
     @abstractmethod
     def _extract_ptr(self) -> AugExpr:
         """Extract the field base pointer.
@@ -47,9 +48,12 @@ class SupportsFieldExtraction(Protocol):
 
         :meta public:
         """
+
+
 #  how-to-guide end
 
 
+@runtime_checkable
 class SupportsVectorExtraction(Protocol):
     """Protocol for component extraction from a vector.
 
diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml
index c87977f..79f06f5 100644
--- a/tests/generator_scripts/index.yaml
+++ b/tests/generator_scripts/index.yaml
@@ -84,6 +84,7 @@ NestedNamespaces:
 ScaleKernel:
 JacobiMdspan:
 StlContainers1D:
+VectorExtraction:
 
 # std::mdspan
 
diff --git a/tests/generator_scripts/source/VectorExtraction.harness.cpp b/tests/generator_scripts/source/VectorExtraction.harness.cpp
new file mode 100644
index 0000000..55c4f05
--- /dev/null
+++ b/tests/generator_scripts/source/VectorExtraction.harness.cpp
@@ -0,0 +1,30 @@
+#include "VectorExtraction.hpp"
+#include <experimental/mdspan>
+#include <memory>
+#include <vector>
+
+#undef NDEBUG
+#include <cassert>
+
+namespace stdex = std::experimental;
+
+using extents_t = stdex::extents<std::int64_t, std::dynamic_extent, std::dynamic_extent, 3>;
+using vector_field_t = stdex::mdspan<double, extents_t, stdex::layout_right>;
+constexpr size_t N{41};
+
+int main(void)
+{
+    auto u_data = std::make_unique<double[]>(N * N * 3);
+    vector_field_t u_field{u_data.get(), extents_t{N, N}};
+    std::vector<double> v{3.1, 3.2, 3.4};
+
+    gen::invoke(u_field, v);
+
+    for (size_t j = 0; j < N; ++j)
+        for (size_t i = 0; i < N; ++i)
+        {
+            assert(u_field(j, i, 0) == v[0]);
+            assert(u_field(j, i, 1) == v[1]);
+            assert(u_field(j, i, 2) == v[2]);
+        }
+}
\ No newline at end of file
diff --git a/tests/generator_scripts/source/VectorExtraction.py b/tests/generator_scripts/source/VectorExtraction.py
new file mode 100644
index 0000000..dc60eca
--- /dev/null
+++ b/tests/generator_scripts/source/VectorExtraction.py
@@ -0,0 +1,21 @@
+from pystencilssfg import SourceFileGenerator
+from pystencilssfg.lang.cpp import std
+import pystencils as ps
+import sympy as sp
+
+std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
+
+with SourceFileGenerator() as sfg:
+    sfg.namespace("gen")
+
+    u_field = ps.fields("u(3): double[2D]", layout="c")
+    u = sp.symbols("u_:3")
+
+    asms = [ps.Assignment(u_field(i), u[i]) for i in range(3)]
+    ker = sfg.kernels.create(asms)
+
+    sfg.function("invoke")(
+        sfg.map_field(u_field, std.mdspan.from_field(u_field, layout_policy="layout_right")),
+        sfg.map_vector(u, std.vector("double", const=True, ref=True).var("vel")),
+        sfg.call(ker)
+    )
diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py
index 5a9150b..1b057bc 100644
--- a/tests/ir/test_postprocessing.py
+++ b/tests/ir/test_postprocessing.py
@@ -1,10 +1,19 @@
 import sympy as sp
-from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
+from pystencils import (
+    fields,
+    kernel,
+    TypedSymbol,
+    Field,
+    FieldType,
+    create_type,
+    Assignment,
+)
 from pystencils.types import PsCustomType
 
 from pystencilssfg.composer import make_sequence
 
 from pystencilssfg.lang import AugExpr, SupportsFieldExtraction
+from pystencilssfg.lang.cpp import std
 
 from pystencilssfg.ir import SfgStatements, SfgSequence
 from pystencilssfg.ir.postprocessing import CallTreePostProcessing
@@ -100,7 +109,9 @@ def test_field_extraction(sfg):
     khandle = sfg.kernels.create(set_constant)
 
     extraction = DemoFieldExtraction("f")
-    call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle))
+    call_tree = make_sequence(
+        sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)
+    )
 
     pp = CallTreePostProcessing()
     free_vars = pp.get_live_variables(call_tree)
@@ -165,3 +176,39 @@ def test_duplicate_field_shapes(sfg):
     for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True):
         assert isinstance(stmt, SfgStatements)
         assert stmt.code_string == line
+
+
+def test_scalar_fields(sfg):
+    sc_expl = Field.create_generic("f", 1, "double", index_shape=(1,))
+    sc_impl = Field.create_generic("f", 1, "double", index_shape=())
+
+    asm_expl = Assignment(sc_expl.center(0), 3)
+    asm_impl = Assignment(sc_impl.center(), 3)
+
+    k_expl = sfg.kernels.create(asm_expl, "expl")
+    k_impl = sfg.kernels.create(asm_impl, "impl")
+
+    tree_expl = make_sequence(
+        sfg.map_field(sc_expl, std.span.from_field(sc_expl)), sfg.call(k_expl)
+    )
+
+    tree_impl = make_sequence(
+        sfg.map_field(sc_impl, std.span.from_field(sc_impl)), sfg.call(k_impl)
+    )
+
+    pp = CallTreePostProcessing()
+    _ = pp.get_live_variables(tree_expl)
+    _ = pp.get_live_variables(tree_impl)
+
+    extraction_expl = tree_expl.children[0]
+    assert isinstance(extraction_expl, SfgSequence)
+
+    extraction_impl = tree_impl.children[0]
+    assert isinstance(extraction_impl, SfgSequence)
+
+    for node1, node2 in zip(
+        extraction_expl.children, extraction_impl.children, strict=True
+    ):
+        assert isinstance(node1, SfgStatements)
+        assert isinstance(node2, SfgStatements)
+        assert node1.code_string == node2.code_string
-- 
GitLab