From 9ccb6d0b6cc7b1aef2a1cc466e2ed8ef663c6197 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 9 Nov 2023 16:55:33 +0900
Subject: [PATCH] slowly cleaning up data types.

---
 pystencilssfg/context.py                      |   9 +-
 pystencilssfg/kernel_namespace.py             |   2 -
 .../source_components/header_include.py       |   3 +-
 pystencilssfg/source_concepts/__init__.py     |   6 +
 pystencilssfg/source_concepts/containers.py   |  36 ------
 pystencilssfg/source_concepts/cpp/__init__.py |   6 +
 .../source_concepts/cpp/std_mdspan.py         |  25 +++--
 .../source_concepts/cpp/std_vector.py         |  73 ++++++++++++
 .../source_concepts/source_concepts.py        |  30 -----
 .../source_concepts/source_objects.py         | 106 ++++++++++++++++++
 pystencilssfg/tree/basic_nodes.py             |  35 ++----
 pystencilssfg/tree/conditional.py             |   5 +-
 pystencilssfg/tree/deferred_nodes.py          |   2 +-
 pystencilssfg/tree/visitors.py                |   8 +-
 pystencilssfg/types.py                        |  12 ++
 tests/mdspan/Makefile                         |   2 +
 tests/mdspan/kernels.py                       |  30 +++--
 tests/mdspan/main.cpp                         |  54 ++++++++-
 18 files changed, 322 insertions(+), 122 deletions(-)
 delete mode 100644 pystencilssfg/source_concepts/containers.py
 create mode 100644 pystencilssfg/source_concepts/cpp/__init__.py
 create mode 100644 pystencilssfg/source_concepts/cpp/std_vector.py
 delete mode 100644 pystencilssfg/source_concepts/source_concepts.py
 create mode 100644 pystencilssfg/source_concepts/source_objects.py
 create mode 100644 pystencilssfg/types.py

diff --git a/pystencilssfg/context.py b/pystencilssfg/context.py
index 659eecc..c4eb595 100644
--- a/pystencilssfg/context.py
+++ b/pystencilssfg/context.py
@@ -13,11 +13,11 @@ from pystencils import Field
 from pystencils.astnodes import KernelFunction
 
 from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
-from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode
+from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgStatements
 from .tree.deferred_nodes import SfgDeferredFieldMapping
 from .tree.builders import SfgBranchBuilder, make_sequence
 from .tree.visitors import CollectIncludes
-from .source_concepts.containers import SrcField
+from .source_concepts import SrcField, TypedSymbolOrObject
 from .source_components import SfgFunction, SfgHeaderInclude
 
 
@@ -167,7 +167,7 @@ class SfgContext:
         
 
     #----------------------------------------------------------------------------------------------
-    #   Call Tree Node Factory
+    #   In-Sequence builders to be used within the second phase of SfgContext.function().
     #----------------------------------------------------------------------------------------------
 
     def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
@@ -182,4 +182,7 @@ class SfgContext:
             raise NotImplementedError("Automatic field extraction is not implemented yet.")
         else:
             return SfgDeferredFieldMapping(field, src_object)
+    
+    def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
+        return SfgStatements(mapping, (lhs,), (rhs,))
     
\ No newline at end of file
diff --git a/pystencilssfg/kernel_namespace.py b/pystencilssfg/kernel_namespace.py
index 191d6d3..890ff60 100644
--- a/pystencilssfg/kernel_namespace.py
+++ b/pystencilssfg/kernel_namespace.py
@@ -1,5 +1,3 @@
-# from .context import SfgContext
-
 from typing import Sequence
 
 from pystencils import CreateKernelConfig, create_kernel
diff --git a/pystencilssfg/source_components/header_include.py b/pystencilssfg/source_components/header_include.py
index 915423d..3345456 100644
--- a/pystencilssfg/source_components/header_include.py
+++ b/pystencilssfg/source_components/header_include.py
@@ -24,6 +24,7 @@ class SfgHeaderInclude:
         return hash((self._header_file, self._system_header, self._private))
     
     def __eq__(self, other: SfgHeaderInclude) -> bool:
-        return (self._header_file == other._header_file
+        return (isinstance(other, SfgHeaderInclude) 
+                and self._header_file == other._header_file
                 and self._system_header == other._system_header
                 and self._private == other._private)
diff --git a/pystencilssfg/source_concepts/__init__.py b/pystencilssfg/source_concepts/__init__.py
index e69de29..10ba677 100644
--- a/pystencilssfg/source_concepts/__init__.py
+++ b/pystencilssfg/source_concepts/__init__.py
@@ -0,0 +1,6 @@
+from .source_objects import SrcObject, SrcField, SrcVector, PsType, SrcType, TypedSymbolOrObject
+
+__all__ = [
+    SrcObject, SrcField, SrcVector,
+    PsType, SrcType, TypedSymbolOrObject
+]
\ No newline at end of file
diff --git a/pystencilssfg/source_concepts/containers.py b/pystencilssfg/source_concepts/containers.py
deleted file mode 100644
index c634310..0000000
--- a/pystencilssfg/source_concepts/containers.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from typing import Optional, Union
-from abc import ABC, abstractmethod
-
-from pystencils import Field
-from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
-
-from .source_concepts import SrcObject
-from ..tree import SfgStatements, SfgSequence
-
-class SrcField(SrcObject):
-    def __init__(self, src_type, identifier: Optional[str]):
-        super().__init__(src_type, identifier)
-
-    @abstractmethod
-    def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements:
-        pass
-
-    @abstractmethod
-    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
-        pass
-
-    @abstractmethod
-    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
-        pass
-
-    def extract_parameters(self, field: Field) -> SfgSequence:
-        ptr = FieldPointerSymbol(field.name, field.dtype, False)
-
-        from ..tree import make_sequence
-
-        return make_sequence(
-            self.extract_ptr(ptr),
-            *(self.extract_size(c, s) for c, s in enumerate(field.shape)),
-            *(self.extract_stride(c, s) for c, s in enumerate(field.strides))
-        )
-
diff --git a/pystencilssfg/source_concepts/cpp/__init__.py b/pystencilssfg/source_concepts/cpp/__init__.py
new file mode 100644
index 0000000..4d3397a
--- /dev/null
+++ b/pystencilssfg/source_concepts/cpp/__init__.py
@@ -0,0 +1,6 @@
+from .std_mdspan import std_mdspan
+from .std_vector import std_vector, std_vector_ref
+
+__all__= [
+    std_mdspan, std_vector, std_vector_ref
+]
\ No newline at end of file
diff --git a/pystencilssfg/source_concepts/cpp/std_mdspan.py b/pystencilssfg/source_concepts/cpp/std_mdspan.py
index c4ff910..2b99b86 100644
--- a/pystencilssfg/source_concepts/cpp/std_mdspan.py
+++ b/pystencilssfg/source_concepts/cpp/std_mdspan.py
@@ -1,22 +1,22 @@
 from typing import Set, Union, Tuple
-from numpy import dtype
 
 from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
 
-from pystencilssfg.source_components import SfgHeaderInclude
-
 from ...tree import SfgStatements
-from ..containers import SrcField
+from ..source_objects import SrcField
 from ...source_components.header_include import SfgHeaderInclude
+from ..source_objects import PsType
 from ...exceptions import SfgException
 
 class std_mdspan(SrcField):
     dynamic_extent = "std::dynamic_extent"
 
-    def __init__(self, identifer: str, T: dtype, extents: Tuple[int, str]):
+    def __init__(self, identifer: str, T: PsType, extents: Tuple[int, str], extents_type: PsType = int, reference: bool = False):
         from pystencils.typing import create_type
         T = create_type(T)
-        typestring = f"std::mdspan< {T}, std::extents< int, {', '.join(str(e) for e in extents)} > >"
+        extents_type = create_type(extents_type)
+
+        typestring = f"std::mdspan< {T}, std::extents< {extents_type}, {', '.join(str(e) for e in extents)} > > {'&' if reference else ''}"
         super().__init__(typestring, identifer)
 
         self._extents = extents
@@ -33,8 +33,15 @@ class std_mdspan(SrcField):
         )
 
     def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
-        if coordinate >= len(self._extents):
-            raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan")
+        dim = len(self._extents)
+        if coordinate >= dim:
+            if isinstance(size, FieldShapeSymbol):
+                raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!")
+            elif size != 1:
+                raise SfgException(f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!")
+            else:
+                #   trivial trailing index dimensions are OK -> do nothing
+                return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ())
 
         if isinstance(size, FieldShapeSymbol):
             return SfgStatements(
@@ -50,7 +57,7 @@ class std_mdspan(SrcField):
         
     def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
         if coordinate >= len(self._extents):
-            raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan")
+            raise SfgException(f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan")
         
         if isinstance(stride, FieldStrideSymbol):
             return SfgStatements(
diff --git a/pystencilssfg/source_concepts/cpp/std_vector.py b/pystencilssfg/source_concepts/cpp/std_vector.py
new file mode 100644
index 0000000..5f2e8f0
--- /dev/null
+++ b/pystencilssfg/source_concepts/cpp/std_vector.py
@@ -0,0 +1,73 @@
+from typing import Set, Union, Tuple
+
+from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol, create_type
+
+from ...tree import SfgStatements
+from ..source_objects import SrcField, SrcVector
+from ..source_objects import SrcObject, SrcType, TypedSymbolOrObject
+from ...source_components.header_include import SfgHeaderInclude
+from ...exceptions import SfgException
+
+class std_vector(SrcVector, SrcField):
+    def __init__(self, identifer: str, T: SrcType, unsafe: bool = False):
+        typestring = f"std::vector< {T} >"
+        super(SrcObject, self).__init__(identifer, typestring)
+
+        self._element_type = T
+        self._unsafe = unsafe
+
+    @property
+    def required_includes(self) -> Set[SfgHeaderInclude]:
+        return { SfgHeaderInclude("vector", system_header=True) }
+    
+    def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
+        if ptr_symbol.dtype != self._element_type:
+            if self._unsafe:
+                mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = ({ptr_symbol.dtype}) {self._identifier}.data();"
+            else:
+                raise SfgException("Field type and std::vector element type do not match, and unsafe extraction was not enabled.")
+        else:
+            mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data();"
+
+        return SfgStatements(mapping, (ptr_symbol,), (self,))
+    
+    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
+        if coordinate > 0:
+            raise SfgException(f"Cannot extract size in coordinate {coordinate} from std::vector")
+
+        if isinstance(size, FieldShapeSymbol):
+            return SfgStatements(
+                    f"{size.dtype} {size.name} = {self._identifier}.size();",
+                    (size, ),
+                    (self, )
+                )
+        else:
+            return SfgStatements(
+                f"assert( {self._identifier}.size() == {size} );",
+                (), (self, )
+            )
+        
+    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
+        if coordinate > 0:
+            raise SfgException(f"Cannot extract stride in coordinate {coordinate} from std::vector")
+        
+        if isinstance(stride, FieldStrideSymbol):
+            return SfgStatements(f"{stride.dtype} {stride.name} = 1;", (stride, ), ())
+        else:
+            return SfgStatements(f"assert( 1 == {stride} );", (), ())
+
+
+    def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
+        if self._unsafe:
+            mapping = f"{destination.dtype} {destination.name} = {self._identifier}[{coordinate}];"
+        else:
+            mapping = f"{destination.dtype} {destination.name} = {self._identifier}.at({coordinate});"
+
+        return SfgStatements(mapping, (destination,), (self,))
+
+
+
+class std_vector_ref(std_vector):
+    def __init__(self, identifer: str, T: SrcType):
+        typestring = f"std::vector< {T} > &"
+        super(SrcObject, self).__init__(identifer, typestring)
diff --git a/pystencilssfg/source_concepts/source_concepts.py b/pystencilssfg/source_concepts/source_concepts.py
deleted file mode 100644
index 66a22d8..0000000
--- a/pystencilssfg/source_concepts/source_concepts.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional, Sequence, Union, Set
-
-if TYPE_CHECKING:
-    from ..source_components import SfgHeaderInclude
-
-from abc import ABC
-from pystencils import TypedSymbol
-
-class SrcObject(ABC):
-    def __init__(self, src_type, identifier: Optional[str]):
-        self._src_type = src_type
-        self._identifier = identifier
-
-    @property
-    def src_type(self):
-        return self._src_type
-    
-    @property
-    def identifier(self):
-        return self._identifier
-
-    @property
-    def required_includes(self) -> Set[SfgHeaderInclude]:
-        return set()
-
-    @property
-    def typed_symbol(self):
-        return TypedSymbol(self._identifier, self._src_type)
diff --git a/pystencilssfg/source_concepts/source_objects.py b/pystencilssfg/source_concepts/source_objects.py
new file mode 100644
index 0000000..244ff75
--- /dev/null
+++ b/pystencilssfg/source_concepts/source_objects.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias, NewType
+
+if TYPE_CHECKING:
+    from ..source_components import SfgHeaderInclude
+    from ..tree import SfgStatements, SfgSequence
+
+from numpy import dtype
+
+from abc import ABC, abstractmethod
+
+from pystencils import TypedSymbol, Field
+from pystencils.typing import AbstractType, FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
+
+PsType: TypeAlias = Union[type, dtype, AbstractType]
+"""Types used in interacting with pystencils.
+
+PsType represents various ways of specifying types within pystencils.
+In particular, it encompasses most ways to construct an instance of `AbstractType`,
+for example via `create_type`.
+
+(Note that, while `create_type` does accept strings, they are excluded here for
+reasons of safety. It is discouraged to use strings for type specifications when working
+with pystencils!)
+"""
+
+SrcType = NewType('SrcType', str)
+"""Nonprimitive C/C++-Types occuring during source file generation.
+
+Nonprimitive C/C++ types are represented by their names.
+When necessary, the SFG package checks equality of types by these name strings; it does
+not care about typedefs, aliases, namespaces, etc!
+"""
+
+
+class SrcObject:
+    """C/C++ object of nonprimitive type.
+    
+    Two objects are identical if they have the same identifier and type string."""
+
+    def __init__(self, src_type: SrcType, identifier: Optional[str]):
+        self._src_type = src_type
+        self._identifier = identifier
+    
+    @property
+    def identifier(self):
+        return self._identifier
+
+    @property
+    def name(self):
+        """For interface compatibility with ps.TypedSymbol"""
+        return self._identifier
+
+    @property
+    def dtype(self):
+        return self._src_type
+
+    @property
+    def required_includes(self) -> Set[SfgHeaderInclude]:
+        return set()
+    
+    def __hash__(self) -> int:
+        return hash((self._identifier, self._src_type))
+    
+    def __eq__(self, other: SrcObject) -> bool:
+        return (isinstance(other, SrcObject) 
+                and self._identifier == other._identifier
+                and self._src_type == other._src_type)
+
+
+TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject]
+
+
+class SrcField(SrcObject, ABC):
+    def __init__(self, src_type: SrcType, identifier: Optional[str]):
+        super().__init__(src_type, identifier)
+
+    @abstractmethod
+    def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements:
+        pass
+
+    @abstractmethod
+    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
+        pass
+
+    @abstractmethod
+    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
+        pass
+
+    def extract_parameters(self, field: Field) -> SfgSequence:
+        ptr = FieldPointerSymbol(field.name, field.dtype, False)
+
+        from ..tree import make_sequence
+
+        return make_sequence(
+            self.extract_ptr(ptr),
+            *(self.extract_size(c, s) for c, s in enumerate(field.shape)),
+            *(self.extract_stride(c, s) for c, s in enumerate(field.strides))
+        )
+
+
+class SrcVector(SrcObject):
+    @abstractmethod
+    def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
+        pass
diff --git a/pystencilssfg/tree/basic_nodes.py b/pystencilssfg/tree/basic_nodes.py
index 429f504..75fceb5 100644
--- a/pystencilssfg/tree/basic_nodes.py
+++ b/pystencilssfg/tree/basic_nodes.py
@@ -6,13 +6,10 @@ if TYPE_CHECKING:
     from ..source_components import SfgHeaderInclude
 
 from abc import ABC, abstractmethod
-from functools import reduce
 from itertools import chain
 
-from jinja2.filters import do_indent
-
 from ..kernel_namespace import SfgKernelHandle
-from ..source_concepts.source_concepts import SrcObject
+from  ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject
 
 from ..exceptions import SfgException
 
@@ -54,7 +51,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
 
     @property
     @abstractmethod
-    def required_symbols(self) -> Set[TypedSymbol]:
+    def required_parameters(self) -> Set[TypedSymbolOrObject]:
         pass
 
 
@@ -77,33 +74,25 @@ class SfgStatements(SfgCallTreeLeaf):
 
     def __init__(self, 
                  code_string: str,
-                 defined_objects: Sequence[Union[SrcObject, TypedSymbol]],
-                 required_objects: Sequence[Union[SrcObject, TypedSymbol]]):
+                 defined_params: Sequence[TypedSymbolOrObject],
+                 required_params: Sequence[TypedSymbolOrObject]):
         self._code_string = code_string
         
-        def to_symbol(obj: Union[SrcObject, TypedSymbol]):
-            if isinstance(obj, SrcObject):
-                return obj.typed_symbol
-            elif isinstance(obj, TypedSymbol):
-                return obj
-            else:
-                raise ValueError(f"Required object in expression is neither TypedSymbol nor SrcObject: {obj}")
-        
-        self._defined_symbols = set(map(to_symbol, defined_objects))
-        self._required_symbols = set(map(to_symbol, required_objects))
+        self._defined_params = set(defined_params)
+        self._required_params = set(required_params)
 
         self._required_includes = set()
-        for obj in chain(required_objects, defined_objects):
+        for obj in chain(required_params, defined_params):
             if isinstance(obj, SrcObject):
                 self._required_includes |= obj.required_includes
             
     @property
-    def required_symbols(self) -> Set[TypedSymbol]:
-        return self._required_symbols
+    def required_parameters(self) -> Set[TypedSymbolOrObject]:
+        return self._required_params
     
     @property
-    def defined_symbols(self) -> Set[TypedSymbol]:
-        return self._defined_symbols
+    def defined_parameters(self) -> Set[TypedSymbolOrObject]:
+        return self._defined_params
     
     @property
     def required_includes(self) -> Set[SfgHeaderInclude]:
@@ -153,7 +142,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf):
         self._kernel_handle = kernel_handle
 
     @property
-    def required_symbols(self) -> Set[TypedSymbol]:
+    def required_parameters(self) -> Set[TypedSymbolOrObject]:
         return set(p.symbol for p in self._kernel_handle.parameters)
     
     def get_code(self, ctx: SfgContext) -> str:
diff --git a/pystencilssfg/tree/conditional.py b/pystencilssfg/tree/conditional.py
index acfcc61..8731344 100644
--- a/pystencilssfg/tree/conditional.py
+++ b/pystencilssfg/tree/conditional.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence, Optional
+from typing import TYPE_CHECKING, Sequence, Optional, Set
 
 if TYPE_CHECKING:
     from ..context import SfgContext
@@ -8,6 +8,7 @@ from jinja2.filters import do_indent
 from pystencils.typing import TypedSymbol
 
 from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf
+from ..source_concepts.source_objects import TypedSymbolOrObject
 
 class SfgCondition(SfgCallTreeLeaf):
     pass
@@ -16,7 +17,7 @@ class SfgCustomCondition(SfgCondition):
     def __init__(self, cond_text: str):
         self._cond_text = cond_text
 
-    def required_symbols(self) -> set(TypedSymbol):
+    def required_parameters(self) -> Set[TypedSymbolOrObject]:
         return set()
 
     def get_code(self, ctx: SfgContext) -> str:
diff --git a/pystencilssfg/tree/deferred_nodes.py b/pystencilssfg/tree/deferred_nodes.py
index 349c7bd..b14ba2d 100644
--- a/pystencilssfg/tree/deferred_nodes.py
+++ b/pystencilssfg/tree/deferred_nodes.py
@@ -14,7 +14,7 @@ from ..exceptions import SfgException
 from .basic_nodes import SfgCallTreeNode
 from .builders import make_sequence
 
-from ..source_concepts.containers import SrcField
+from ..source_concepts import SrcField
 
 
 class SfgDeferredNode(SfgCallTreeNode, ABC):
diff --git a/pystencilssfg/tree/visitors.py b/pystencilssfg/tree/visitors.py
index 88c078d..9da7db2 100644
--- a/pystencilssfg/tree/visitors.py
+++ b/pystencilssfg/tree/visitors.py
@@ -66,7 +66,7 @@ class ExpandingParameterCollector():
             return self._visit_branchingNode(node)
 
     def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
-        return leaf.required_symbols
+        return leaf.required_parameters
 
     def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
         """
@@ -87,7 +87,7 @@ class ExpandingParameterCollector():
                     iter_nested_sequences(c, visible_params)
                 else:
                     if isinstance(c, SfgStatements):
-                        visible_params -= c.defined_symbols
+                        visible_params -= c.defined_parameters
                     
                     visible_params |= self.visit(c)
 
@@ -117,7 +117,7 @@ class ParameterCollector():
             return self._visit_branchingNode(node)
 
     def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
-        return leaf.required_symbols
+        return leaf.required_parameters
 
     def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
         """
@@ -127,7 +127,7 @@ class ParameterCollector():
         params = set()
         for c in sequence.children[::-1]:
             if isinstance(c, SfgStatements):
-                params -= c.defined_symbols
+                params -= c.defined_parameters
             
             assert not isinstance(c, SfgSequence), "Sequence not flattened."
             params |= self.visit(c)
diff --git a/pystencilssfg/types.py b/pystencilssfg/types.py
new file mode 100644
index 0000000..9e309e8
--- /dev/null
+++ b/pystencilssfg/types.py
@@ -0,0 +1,12 @@
+from pystencils.typing import AbstractType, BasicType, StructType, PointerType
+
+
+class SrcType:
+    """Valid C/C++-Type occuring during source file generation.
+
+    Nonprimitive C/C++ types are represented by their names.
+    When necessary, the SFG package checks equality of types by these name strings; it does
+    not care about typedefs, aliases, namespaces, etc!
+    """
+    
+
diff --git a/tests/mdspan/Makefile b/tests/mdspan/Makefile
index 76e57a6..1761528 100644
--- a/tests/mdspan/Makefile
+++ b/tests/mdspan/Makefile
@@ -27,8 +27,10 @@ $(OBJ)/kernels.o: $(GEN_SRC)/kernels.cpp $(GEN_SRC)/kernels.h
 	$(CXX) $(CXX_FLAGS) -c -o $@ $<
 
 $(OBJ)/main.o: main.cpp $(GEN_SRC)/kernels.h
+	@$(dir_guard)
 	$(CXX) $(CXX_FLAGS) -c -o $@ $<
 
 $(BIN)/mdspan_test: $(OBJ)/kernels.o $(OBJ)/main.o
+	@$(dir_guard)
 	$(CXX) $(CXX_FLAGS) -o $@ $^
 
diff --git a/tests/mdspan/kernels.py b/tests/mdspan/kernels.py
index d8f582c..27a482a 100644
--- a/tests/mdspan/kernels.py
+++ b/tests/mdspan/kernels.py
@@ -1,22 +1,32 @@
+import sympy as sp
 import numpy as np
+
 from pystencils.session import *
 
 from pystencilssfg import SourceFileGenerator
-from pystencilssfg.source_concepts.cpp.std_mdspan import std_mdspan
+from pystencilssfg.source_concepts.cpp import std_mdspan
+
+def field_t(field: ps.Field):
+    return std_mdspan(field.name,
+                      field.dtype,
+                      (std_mdspan.dynamic_extent, std_mdspan.dynamic_extent),
+                      extents_type=np.uint32,
+                      reference=True)
+
 
-with SourceFileGenerator() as sfg:
+with SourceFileGenerator("poisson") as sfg:
     src, dst = ps.fields("src, dst(1) : double[2D]")
 
-    @ps.kernel
-    def poisson_gs():
-        dst[0,0] @= src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1] - 4 * src[0, 0]
+    h = sp.Symbol('h')
 
-    sfg.include("<iostream>")
+    @ps.kernel
+    def poisson_jacobi():
+        dst[0,0] @= (src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1]) / 4
 
-    poisson_kernel = sfg.kernels.create(poisson_gs)
+    poisson_kernel = sfg.kernels.create(poisson_jacobi)
 
-    sfg.function("myFunction")(
-        sfg.map_field(src, std_mdspan(src.name, np.float64, (std_mdspan.dynamic_extent, std_mdspan.dynamic_extent, 1))),
-        sfg.map_field(dst, std_mdspan(dst.name, np.float64, (2, 2, 1))),
+    sfg.function("jacobi_smooth")(
+        sfg.map_field(src, field_t(src)),
+        sfg.map_field(dst, field_t(dst)),
         sfg.call(poisson_kernel)
     )
diff --git a/tests/mdspan/main.cpp b/tests/mdspan/main.cpp
index f44ca21..d8247a5 100644
--- a/tests/mdspan/main.cpp
+++ b/tests/mdspan/main.cpp
@@ -1,6 +1,58 @@
+#include <iostream>
+#include <fstream>
+
+#include <cstdint>
+#include <vector>
+
+#include <experimental/mdspan>
+
 #include "generated_src/kernels.h"
 
+using field_t = std::mdspan< double, std::extents< uint32_t, std::dynamic_extent, std::dynamic_extent > >;
+
+double boundary(double x, double y){
+    return 1.0;
+}
+
 int main(int argc, char ** argv){
-    pystencils::myFunction();
+    uint32_t N = 8; /* number of grid nodes */
+    double h = 1.0 / (double(N) - 1);
+    uint32_t n_iters = 100;
+
+    std::vector< double > data_src(N*N);
+    field_t src(data_src.data(), N, N);
+
+    std::vector< double > data_dst(N*N);
+    field_t dst(data_dst.data(), N, N);
+
+    for(uint32_t i = 0; i < N; ++i){
+        for(uint32_t j = 0; j < N; ++j){
+            if(i == 0 || j == 0 || i == N-1 || j == N-1){
+                src[i, j] = boundary(double(i) * h, double(j) * h);
+                dst[i, j] = boundary(double(i) * h, double(j) * h);
+            }
+        }
+    }
+    
+    for(uint32_t i = 0; i < n_iters; ++i){
+        poisson::jacobi_smooth(dst, src);
+        std::swap(src, dst);
+    }
+
+    std::ofstream f("data.out", std::ios::trunc | std::ios::out);
+
+    if(!f.is_open()){
+        std::cerr << "Could not open output file.\n";
+    } else {
+        for(uint32_t i = 0; i < N; ++i){
+            for(uint32_t j = 0; j < N; ++j){
+                f << src[i, j] << " ";
+            }
+            f << '\n';
+        }
+    }
+
+    f.close();
+
     return 0;
 }
-- 
GitLab