From 2975c223a5fbbc4c5e45d0ef3bb5f639de338fc2 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 9 Nov 2023 13:28:17 +0900 Subject: [PATCH] completed field mapping, mdspan integration test work in progress --- pystencilssfg/context.py | 63 +++++++++----- pystencilssfg/emitters/cpu/basic_cpu.py | 3 +- .../emitters/cpu/templates/BasicCpu.tmpl.cpp | 6 +- .../emitters/cpu/templates/BasicCpu.tmpl.h | 7 +- pystencilssfg/exceptions.py | 3 + pystencilssfg/kernel_namespace.py | 4 +- pystencilssfg/source_components/__init__.py | 6 ++ .../function.py} | 16 ++-- .../source_components/header_include.py | 29 +++++++ .../source_concepts/cpp/std_mdspan.py | 32 ++++++-- .../source_concepts/source_concepts.py | 31 ++----- pystencilssfg/tree/basic_nodes.py | 40 ++++++++- pystencilssfg/tree/conditional.py | 6 ++ pystencilssfg/tree/deferred_nodes.py | 82 +++++++++++++++++++ pystencilssfg/tree/visitors.py | 72 +++++++++++++++- tests/mdspan/Makefile | 4 +- tests/mdspan/kernels.py | 16 +++- 17 files changed, 352 insertions(+), 68 deletions(-) create mode 100644 pystencilssfg/exceptions.py create mode 100644 pystencilssfg/source_components/__init__.py rename pystencilssfg/{source_components.py => source_components/function.py} (65%) create mode 100644 pystencilssfg/source_components/header_include.py create mode 100644 pystencilssfg/tree/deferred_nodes.py diff --git a/pystencilssfg/context.py b/pystencilssfg/context.py index 572d843..659eecc 100644 --- a/pystencilssfg/context.py +++ b/pystencilssfg/context.py @@ -13,10 +13,12 @@ from pystencils import Field from pystencils.astnodes import KernelFunction from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle -from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgCondition, SfgBranch +from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode +from .tree.deferred_nodes import SfgDeferredFieldMapping from .tree.builders import SfgBranchBuilder, make_sequence +from .tree.visitors import CollectIncludes from .source_concepts.containers import SrcField -from .source_components import SfgFunction +from .source_components import SfgFunction, SfgHeaderInclude @dataclass @@ -73,7 +75,7 @@ class SfgContext: self._default_kernel_namespace = SfgKernelNamespace(self, "kernels") # Source Components - self._includes = [] + self._includes = set() self._kernel_namespaces = { self._default_kernel_namespace.name : self._default_kernel_namespace } self._functions = dict() @@ -89,6 +91,38 @@ class SfgContext: def codestyle(self) -> SfgCodeStyle: return self._codestyle + #---------------------------------------------------------------------------------------------- + # Source Component Getters + #---------------------------------------------------------------------------------------------- + + def includes(self) -> Generator[SfgHeaderInclude, None, None]: + yield from self._includes + + def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: + yield from self._kernel_namespaces.values() + + def functions(self) -> Generator[SfgFunction, None, None]: + yield from self._functions.values() + + #---------------------------------------------------------------------------------------------- + # Source Component Adders + #---------------------------------------------------------------------------------------------- + + def add_include(self, include: SfgHeaderInclude): + self._includes.add(include) + + def add_function(self, func: SfgFunction): + if func.name in self._functions: + raise ValueError(f"Duplicate function: {func.name}") + + self._functions[func.name] = func + for incl in CollectIncludes().visit(func._tree): + self.add_include(incl) + + #---------------------------------------------------------------------------------------------- + # Factory-like Adders + #---------------------------------------------------------------------------------------------- + @property def kernels(self) -> SfgKernelNamespace: return self._default_kernel_namespace @@ -100,23 +134,14 @@ class SfgContext: kns = SfgKernelNamespace(self, name) self._kernel_namespaces[name] = kns return kns - - def includes(self) -> Generator[str, None, None]: - yield from self._includes - - def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: - yield from self._kernel_namespaces.values() - - def functions(self) -> Generator[SfgFunction, None, None]: - yield from self._functions.values() def include(self, header_file: str): - if not (header_file.startswith("<") and header_file.endswith(">")): - if not (header_file.startswith('"') and header_file.endswith('"')): - header_file = f'"{header_file}"' + system_header = False + if header_file.startswith("<") and header_file.endswith(">"): + header_file = header_file[1:-1] + system_header = True - self._includes.append(header_file) - + self.add_include(SfgHeaderInclude(header_file, system_header=system_header)) def function(self, name: str, @@ -136,7 +161,7 @@ class SfgContext: def sequencer(*args: SfgCallTreeNode): tree = make_sequence(*args) func = SfgFunction(self, name, tree) - self._functions[name] = func + self.add_function(func) return sequencer @@ -156,5 +181,5 @@ class SfgContext: if src_object is None: raise NotImplementedError("Automatic field extraction is not implemented yet.") else: - return src_object.extract_parameters(field) + return SfgDeferredFieldMapping(field, src_object) \ No newline at end of file diff --git a/pystencilssfg/emitters/cpu/basic_cpu.py b/pystencilssfg/emitters/cpu/basic_cpu.py index d36c98e..954b300 100644 --- a/pystencilssfg/emitters/cpu/basic_cpu.py +++ b/pystencilssfg/emitters/cpu/basic_cpu.py @@ -24,7 +24,8 @@ class BasicCpuEmitter: 'ctx': self._ctx, 'basename': self._basename, 'root_namespace': self._ctx.root_namespace, - 'includes': list(self._ctx.includes()), + 'public_includes': list(incl.get_code() for incl in self._ctx.includes() if not incl.private), + 'private_includes': list(incl.get_code() for incl in self._ctx.includes() if incl.private), 'kernel_namespaces': list(self._ctx.kernel_namespaces()), 'functions': list(self._ctx.functions()) } diff --git a/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp b/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp index c5d368a..80cfe4d 100644 --- a/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp +++ b/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.cpp @@ -1,5 +1,9 @@ #include "{{basename}}.h" +{% for incl in private_includes -%} +{{incl}} +{% endfor %} + #define FUNC_PREFIX inline namespace {{root_namespace}} { @@ -9,7 +13,7 @@ namespace {{root_namespace}} { *************************************************************************************/ {% for kns in kernel_namespaces -%} -namespace {{ kns.name }}{ +namespace {{ kns.name }} { {% for ast in kns.asts %} {{ ast | generate_kernel_definition }} diff --git a/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.h b/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.h index 8301a21..cf594e5 100644 --- a/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.h +++ b/pystencilssfg/emitters/cpu/templates/BasicCpu.tmpl.h @@ -1,13 +1,12 @@ #pragma once -#define RESTRICT __restrict__ - #include <cstdint> -{% for header in includes %} -#include {{header}} +{% for incl in public_includes -%} +{{incl}} {% endfor %} +#define RESTRICT __restrict__ namespace {{root_namespace}} { diff --git a/pystencilssfg/exceptions.py b/pystencilssfg/exceptions.py new file mode 100644 index 0000000..1351733 --- /dev/null +++ b/pystencilssfg/exceptions.py @@ -0,0 +1,3 @@ + +class SfgException(Exception): + pass diff --git a/pystencilssfg/kernel_namespace.py b/pystencilssfg/kernel_namespace.py index 4af5da4..191d6d3 100644 --- a/pystencilssfg/kernel_namespace.py +++ b/pystencilssfg/kernel_namespace.py @@ -29,8 +29,8 @@ class SfgKernelNamespace: return SfgKernelHandle(self._ctx, astname, self, ast.get_parameters()) - def create(self, assignments, config: CreateKernelConfig): - ast = create_kernel(assignments, config) + def create(self, assignments, config: CreateKernelConfig = None): + ast = create_kernel(assignments, config=config) return self.add(ast) diff --git a/pystencilssfg/source_components/__init__.py b/pystencilssfg/source_components/__init__.py new file mode 100644 index 0000000..307c354 --- /dev/null +++ b/pystencilssfg/source_components/__init__.py @@ -0,0 +1,6 @@ +from .function import SfgFunction +from .header_include import SfgHeaderInclude + +__all__ = [ + SfgFunction, SfgHeaderInclude +] \ No newline at end of file diff --git a/pystencilssfg/source_components.py b/pystencilssfg/source_components/function.py similarity index 65% rename from pystencilssfg/source_components.py rename to pystencilssfg/source_components/function.py index c296e97..26ff659 100644 --- a/pystencilssfg/source_components.py +++ b/pystencilssfg/source_components/function.py @@ -2,10 +2,10 @@ from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: - from .context import SfgContext - -from .tree import SfgCallTreeNode, SfgSequence -from .tree.visitors import FlattenSequences, ParameterCollector + from ..context import SfgContext + +from ..tree import SfgCallTreeNode +from ..tree.visitors import FlattenSequences, ExpandingParameterCollector class SfgFunction: def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): @@ -14,9 +14,9 @@ class SfgFunction: self._tree = tree flattener = FlattenSequences() - flattener.visit(self._tree) + # flattener.visit(self._tree) - param_collector = ParameterCollector() + param_collector = ExpandingParameterCollector(self._ctx) self._parameters = param_collector.visit(self._tree) @property @@ -26,6 +26,10 @@ class SfgFunction: @property def parameters(self): return self._parameters + + @property + def tree(self): + return self._tree def get_code(self): return self._tree.get_code(self._ctx) diff --git a/pystencilssfg/source_components/header_include.py b/pystencilssfg/source_components/header_include.py new file mode 100644 index 0000000..915423d --- /dev/null +++ b/pystencilssfg/source_components/header_include.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +class SfgHeaderInclude: + def __init__(self, header_file: str, system_header : bool = False, private: bool = False): + self._header_file = header_file + self._system_header = system_header + self._private = private + + @property + def system_header(self): + return self._system_header + + @property + def private(self): + return self._private + + def get_code(self): + if self._system_header: + return f"#include <{self._header_file}>" + else: + return f'#include "{self._header_file}"' + + def __hash__(self) -> int: + return hash((self._header_file, self._system_header, self._private)) + + def __eq__(self, other: SfgHeaderInclude) -> bool: + return (self._header_file == other._header_file + and self._system_header == other._system_header + and self._private == other._private) diff --git a/pystencilssfg/source_concepts/cpp/std_mdspan.py b/pystencilssfg/source_concepts/cpp/std_mdspan.py index 93c8811..c4ff910 100644 --- a/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -1,13 +1,29 @@ -from typing import Union +from typing import Set, Union, Tuple +from numpy import dtype from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol +from pystencilssfg.source_components import SfgHeaderInclude + from ...tree import SfgStatements from ..containers import SrcField +from ...source_components.header_include import SfgHeaderInclude +from ...exceptions import SfgException class std_mdspan(SrcField): - def __init__(self, identifer: str): - super().__init__("std::mdspan", identifer) + dynamic_extent = "std::dynamic_extent" + + def __init__(self, identifer: str, T: dtype, extents: Tuple[int, str]): + from pystencils.typing import create_type + T = create_type(T) + typestring = f"std::mdspan< {T}, std::extents< int, {', '.join(str(e) for e in extents)} > >" + super().__init__(typestring, identifer) + + self._extents = extents + + @property + def required_includes(self) -> Set[SfgHeaderInclude]: + return { SfgHeaderInclude("experimental/mdspan", system_header=True) } def extract_ptr(self, ptr_symbol: FieldPointerSymbol): return SfgStatements( @@ -17,6 +33,9 @@ class std_mdspan(SrcField): ) def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: + if coordinate >= len(self._extents): + raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") + if isinstance(size, FieldShapeSymbol): return SfgStatements( f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});", @@ -29,8 +48,11 @@ class std_mdspan(SrcField): (), (self, ) ) - def extract_stride(self, coordinate: int, stride: Union[int, FieldShapeSymbol]) -> SfgStatements: - if isinstance(stride, FieldShapeSymbol): + def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: + if coordinate >= len(self._extents): + raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") + + if isinstance(stride, FieldStrideSymbol): return SfgStatements( f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});", (stride, ), diff --git a/pystencilssfg/source_concepts/source_concepts.py b/pystencilssfg/source_concepts/source_concepts.py index 4e663f4..66a22d8 100644 --- a/pystencilssfg/source_concepts/source_concepts.py +++ b/pystencilssfg/source_concepts/source_concepts.py @@ -1,6 +1,10 @@ from __future__ import annotations -from typing import Optional, Sequence, Union, Set +from typing import TYPE_CHECKING, Optional, Sequence, Union, Set + +if TYPE_CHECKING: + from ..source_components import SfgHeaderInclude + from abc import ABC from pystencils import TypedSymbol @@ -17,27 +21,10 @@ class SrcObject(ABC): def identifier(self): return self._identifier + @property + def required_includes(self) -> Set[SfgHeaderInclude]: + return set() + @property def typed_symbol(self): return TypedSymbol(self._identifier, self._src_type) - - - - - -# class SrcExpression(SrcStatements): -# """Represents a single expression of the source language, e.g. a C++ expression -# (c.f. https://en.cppreference.com/w/cpp/language/expressions). - -# It is the user's responsibility to ensure that the code string is valid code in the output language, -# and that the list of required objects is complete. - -# Args: -# code_string: Code to be printed out. -# required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to this expression. -# """ - -# def __init__(self, -# code_string: str, -# required_objects: Sequence[Union[SrcObject, TypedSymbol]]): -# super().__init__(code_string, (), required_objects) diff --git a/pystencilssfg/tree/basic_nodes.py b/pystencilssfg/tree/basic_nodes.py index 3750bd5..429f504 100644 --- a/pystencilssfg/tree/basic_nodes.py +++ b/pystencilssfg/tree/basic_nodes.py @@ -3,15 +3,19 @@ from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable if TYPE_CHECKING: from ..context import SfgContext + from ..source_components import SfgHeaderInclude from abc import ABC, abstractmethod from functools import reduce +from itertools import chain from jinja2.filters import do_indent from ..kernel_namespace import SfgKernelHandle from ..source_concepts.source_concepts import SrcObject +from ..exceptions import SfgException + from pystencils.typing import TypedSymbol class SfgCallTreeNode(ABC): @@ -22,6 +26,10 @@ class SfgCallTreeNode(ABC): def children(self) -> Sequence[SfgCallTreeNode]: pass + @abstractmethod + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: + pass + @abstractmethod def get_code(self, ctx: SfgContext) -> str: """Returns the code of this node. @@ -30,12 +38,19 @@ class SfgCallTreeNode(ABC): """ pass + @property + def required_includes(self) -> Set[SfgHeaderInclude]: + return set() + class SfgCallTreeLeaf(SfgCallTreeNode, ABC): @property def children(self) -> Sequence[SfgCallTreeNode]: return () + + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: + raise SfgException("Leaf nodes have no children.") @property @abstractmethod @@ -68,14 +83,19 @@ class SfgStatements(SfgCallTreeLeaf): def to_symbol(obj: Union[SrcObject, TypedSymbol]): if isinstance(obj, SrcObject): - self._required_symbols.add(obj.typed_symbol) + return obj.typed_symbol elif isinstance(obj, TypedSymbol): - self._required_symbols.add(obj) + return obj else: raise ValueError(f"Required object in expression is neither TypedSymbol nor SrcObject: {obj}") self._defined_symbols = set(map(to_symbol, defined_objects)) self._required_symbols = set(map(to_symbol, required_objects)) + + self._required_includes = set() + for obj in chain(required_objects, defined_objects): + if isinstance(obj, SrcObject): + self._required_includes |= obj.required_includes @property def required_symbols(self) -> Set[TypedSymbol]: @@ -84,6 +104,10 @@ class SfgStatements(SfgCallTreeLeaf): @property def defined_symbols(self) -> Set[TypedSymbol]: return self._defined_symbols + + @property + def required_includes(self) -> Set[SfgHeaderInclude]: + return self._required_includes def get_code(self, ctx: SfgContext) -> str: return self._code_string @@ -91,12 +115,15 @@ class SfgStatements(SfgCallTreeLeaf): class SfgSequence(SfgCallTreeNode): def __init__(self, children: Sequence[SfgCallTreeNode]): - self._children = tuple(children) + self._children = list(children) @property def children(self) -> Sequence[SfgCallTreeNode]: return self._children + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: + self._children[child_idx] = node + def get_code(self, ctx: SfgContext) -> str: return "\n".join(c.get_code(ctx) for c in self._children) @@ -108,7 +135,12 @@ class SfgBlock(SfgCallTreeNode): @property def children(self) -> Sequence[SfgCallTreeNode]: - return { self._subtree } + return [self._subtree] + + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: + match child_idx: + case 0: self._subtree = node + case _: raise IndexError(f"Invalid child index: {child_idx}. SfgBlock has only a single child.") def get_code(self, ctx: SfgContext) -> str: subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx)) diff --git a/pystencilssfg/tree/conditional.py b/pystencilssfg/tree/conditional.py index 52f130a..acfcc61 100644 --- a/pystencilssfg/tree/conditional.py +++ b/pystencilssfg/tree/conditional.py @@ -40,6 +40,12 @@ class SfgBranch(SfgCallTreeNode): else: return (self._branch_true,) + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: + match child_idx: + case 0: self._branch_true = node + case 1: self._branch_false = node + case _: raise IndexError(f"Invalid child index: {child_idx}. SfgBlock has only two children.") + def get_code(self, ctx: SfgContext) -> str: code = f"if({self._cond.get_code(ctx)}) {{\n" code += ctx.codestyle.indent(self._branch_true.get_code(ctx)) diff --git a/pystencilssfg/tree/deferred_nodes.py b/pystencilssfg/tree/deferred_nodes.py new file mode 100644 index 0000000..349c7bd --- /dev/null +++ b/pystencilssfg/tree/deferred_nodes.py @@ -0,0 +1,82 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable + +if TYPE_CHECKING: + from ..context import SfgContext + +from abc import ABC, abstractmethod + +from pystencils import Field, TypedSymbol +from pystencils.typing import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol + +from ..exceptions import SfgException + +from .basic_nodes import SfgCallTreeNode +from .builders import make_sequence + +from ..source_concepts.containers import SrcField + + +class SfgDeferredNode(SfgCallTreeNode, ABC): + """Nodes of this type are inserted as placeholders into the kernel call tree and need to be expanded at a later time. + + Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet because information required for their + construction is not yet known. + """ + + @property + def children(self) -> Sequence[SfgCallTreeNode]: + raise SfgException("Deferred nodes cannot be descended into; expand it first.") + + def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: + raise SfgException("Deferred nodes do not have children.") + + def get_code(self, ctx: SfgContext) -> str: + raise SfgException("Deferred nodes can not generate code; they need to be expanded first.") + + @abstractmethod + def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode: + pass + + +class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC): + @abstractmethod + def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode: + pass + + +class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): + def __init__(self, field: Field, src_field: SrcField): + self._field = field + self._src_field = src_field + + def expand(self, ctx: SfgContext, visible_params: Set[TypedSymbol]) -> SfgCallTreeNode: + # Find field pointer + ptr = None + for param in visible_params: + if isinstance(param, FieldPointerSymbol) and param.field_name == self._field.name: + if param.dtype.base_type != self._field.dtype: + raise SfgException("Data type mismatch between field and encountered pointer symbol") + ptr = param + + # Find required sizes + shape = [] + for c, s in enumerate(self._field.shape): + if isinstance(s, FieldShapeSymbol) and s not in visible_params: + continue + else: + shape.append((c, s)) + + # Find required strides + strides = [] + for c, s in enumerate(self._field.strides): + if isinstance(s, FieldStrideSymbol) and s not in visible_params: + continue + else: + strides.append((c, s)) + + return make_sequence( + self._src_field.extract_ptr(ptr), + *(self._src_field.extract_size(c, s) for c, s in shape), + *(self._src_field.extract_stride(c, s) for c, s in strides) + ) diff --git a/pystencilssfg/tree/visitors.py b/pystencilssfg/tree/visitors.py index c94ee53..88c078d 100644 --- a/pystencilssfg/tree/visitors.py +++ b/pystencilssfg/tree/visitors.py @@ -1,9 +1,15 @@ -from typing import Set +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable + +if TYPE_CHECKING: + from ..context import SfgContext + from functools import reduce from pystencils.typing import TypedSymbol from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements +from .deferred_nodes import SfgParamCollectionDeferredNode class FlattenSequences(): @@ -33,6 +39,70 @@ class FlattenSequences(): sequence._children = children_flattened +class CollectIncludes: + def visit(self, node: SfgCallTreeNode): + includes = node.required_includes + for c in node.children: + includes |= self.visit(c) + + return includes + + +class ExpandingParameterCollector(): + def __init__(self, ctx: SfgContext) -> None: + self._ctx = ctx + + self._flattener = FlattenSequences() + + """Collects all parameters required but not defined in a kernel call tree. + Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way. + """ + def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: + if isinstance(node, SfgCallTreeLeaf): + return self._visit_SfgCallTreeLeaf(node) + elif isinstance(node, SfgSequence): + return self._visit_SfgSequence(node) + else: + return self._visit_branchingNode(node) + + def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]: + return leaf.required_symbols + + def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: + """ + Only in a sequence may parameters be defined and visible to subsequent nodes. + """ + + params = set() + + def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]): + for i in range(len(seq.children) - 1, -1, -1): + c = seq.children[i] + + if isinstance(c, SfgParamCollectionDeferredNode): + c = c.expand(self._ctx, visible_params=visible_params) + seq.replace_child(i, c) + + if isinstance(c, SfgSequence): + iter_nested_sequences(c, visible_params) + else: + if isinstance(c, SfgStatements): + visible_params -= c.defined_symbols + + visible_params |= self.visit(c) + + iter_nested_sequences(sequence, params) + + return params + + def _visit_branchingNode(self, node: SfgCallTreeNode): + """ + Each interior node that is not a sequence simply requires the union of all parameters + required by its children. + """ + return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set()) + + class ParameterCollector(): """Collects all parameters required but not defined in a kernel call tree. diff --git a/tests/mdspan/Makefile b/tests/mdspan/Makefile index cf51bde..76e57a6 100644 --- a/tests/mdspan/Makefile +++ b/tests/mdspan/Makefile @@ -1,6 +1,6 @@ CXX := clang++ -CXX_FLAGS := '-DDEBUG -g' +CXX_FLAGS := -DDEBUG -g -std=c++2b -I/home/fhennig/lssgit/mdspan/include PYTHON := python @@ -16,7 +16,7 @@ GEN_SRC := generated_src all: $(BIN)/mdspan_test clean: - rm -rf $(BIN) $(OUT) $(GEN_SRC) + rm -rf $(BIN) $(OBJ) $(GEN_SRC) $(GEN_SRC)/kernels.cpp $(GEN_SRC)/kernels.h &: kernels.py @$(dir_guard) diff --git a/tests/mdspan/kernels.py b/tests/mdspan/kernels.py index 11942a3..d8f582c 100644 --- a/tests/mdspan/kernels.py +++ b/tests/mdspan/kernels.py @@ -1,8 +1,22 @@ +import numpy as np +from pystencils.session import * + from pystencilssfg import SourceFileGenerator +from pystencilssfg.source_concepts.cpp.std_mdspan import std_mdspan with SourceFileGenerator() as sfg: + src, dst = ps.fields("src, dst(1) : double[2D]") + + @ps.kernel + def poisson_gs(): + dst[0,0] @= src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1] - 4 * src[0, 0] + sfg.include("<iostream>") + poisson_kernel = sfg.kernels.create(poisson_gs) + sfg.function("myFunction")( - r'std::cout << "mdspans!\n";' + sfg.map_field(src, std_mdspan(src.name, np.float64, (std_mdspan.dynamic_extent, std_mdspan.dynamic_extent, 1))), + sfg.map_field(dst, std_mdspan(dst.name, np.float64, (2, 2, 1))), + sfg.call(poisson_kernel) ) -- GitLab