Skip to content
Snippets Groups Projects
Commit 3edfb527 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

added function params + documentation for sequencing

parent 4e5c7b3a
Branches
Tags
No related merge requests found
Pipeline #57802 passed with stages
in 43 seconds
::: pystencilssfg.context.SfgContext ::: pystencilssfg.context.SfgContext
::: pystencilssfg.composer.SfgComposer ::: pystencilssfg.composer
::: pystencilssfg.source_components.SfgKernelNamespace
::: pystencilssfg.source_components.SfgKernelHandle
::: pystencilssfg.source_components.SfgFunction
::: pystencilssfg.source_components
...@@ -21,6 +21,8 @@ plugins: ...@@ -21,6 +21,8 @@ plugins:
paths: [src] paths: [src]
options: options:
heading_level: 2 heading_level: 2
members_order: source
group_by_category: False
show_root_heading: True show_root_heading: True
show_root_full_path: False show_root_full_path: False
show_symbol_type_heading: True show_symbol_type_heading: True
...@@ -44,5 +46,6 @@ nav: ...@@ -44,5 +46,6 @@ nav:
- 'API Documentation': - 'API Documentation':
- 'Overview': api/index.md - 'Overview': api/index.md
- 'Source File Generator': api/generator.md - 'Source File Generator': api/generator.md
- 'Composer and Source File Components': api/composition.md - 'Composer': api/composition.md
- 'Source File Components': api/source_components.md
- 'Kernel Call Tree': api/tree.md - 'Kernel Call Tree': api/tree.md
...@@ -5,10 +5,22 @@ from abc import ABC, abstractmethod ...@@ -5,10 +5,22 @@ from abc import ABC, abstractmethod
from pystencils import Field from pystencils import Field
from pystencils.astnodes import KernelFunction from pystencils.astnodes import KernelFunction
from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements, SfgSequence, SfgBlock from .tree import (
SfgCallTreeNode,
SfgKernelCallNode,
SfgStatements,
SfgFunctionParams,
SfgSequence,
SfgBlock,
)
from .tree.deferred_nodes import SfgDeferredFieldMapping from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch
from .source_components import SfgFunction, SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle from .source_components import (
SfgFunction,
SfgHeaderInclude,
SfgKernelNamespace,
SfgKernelHandle,
)
from .source_concepts import SrcField, TypedSymbolOrObject, SrcVector from .source_concepts import SrcField, TypedSymbolOrObject, SrcVector
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -63,9 +75,18 @@ class SfgComposer: ...@@ -63,9 +75,18 @@ class SfgComposer:
header_file = header_file[1:-1] header_file = header_file[1:-1]
system_header = True system_header = True
self._ctx.add_include(SfgHeaderInclude(header_file, system_header=system_header)) self._ctx.add_include(
SfgHeaderInclude(header_file, system_header=system_header)
)
def kernel_function(
self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle
):
"""Creates a function comprising just a single kernel call.
def kernel_function(self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle): Args:
ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST.
"""
if self._ctx.get_function(name) is not None: if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.") raise ValueError(f"Function {name} already exists.")
...@@ -90,6 +111,9 @@ class SfgComposer: ...@@ -90,6 +111,9 @@ class SfgComposer:
# Function Body # Function Body
) )
``` ```
The function body is constructed via sequencing;
refer to [make_sequence][pystencilssfg.composer.make_sequence].
""" """
if self._ctx.get_function(name) is not None: if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.") raise ValueError(f"Function {name} already exists.")
...@@ -110,8 +134,13 @@ class SfgComposer: ...@@ -110,8 +134,13 @@ class SfgComposer:
return SfgKernelCallNode(kernel_handle) return SfgKernelCallNode(kernel_handle)
def seq(self, *args: SfgCallTreeNode) -> SfgSequence: def seq(self, *args: SfgCallTreeNode) -> SfgSequence:
"""Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]"""
return make_sequence(*args) return make_sequence(*args)
def params(self, *args: TypedSymbolOrObject) -> SfgFunctionParams:
"""Use inside a function body to add parameters to the function."""
return SfgFunctionParams(args)
@property @property
def branch(self) -> SfgBranchBuilder: def branch(self) -> SfgBranchBuilder:
"""Use inside a function body to create an if/else conditonal branch. """Use inside a function body to create an if/else conditonal branch.
...@@ -137,16 +166,21 @@ class SfgComposer: ...@@ -137,16 +166,21 @@ class SfgComposer:
""" """
return SfgDeferredFieldMapping(field, src_object) return SfgDeferredFieldMapping(field, src_object)
def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): def map_param(
self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str
):
"""Arbitrary parameter mapping: Add a single line of code to define a left-hand """Arbitrary parameter mapping: Add a single line of code to define a left-hand
side object from a right-hand side.""" side object from a right-hand side."""
return SfgStatements(mapping, (lhs,), (rhs,)) return SfgStatements(mapping, (lhs,), (rhs,))
def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector): def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector):
"""Extracts scalar numerical values from a vector data type.""" """Extracts scalar numerical values from a vector data type."""
return make_sequence(*( return make_sequence(
rhs.extract_component(dest, coord) for coord, dest in enumerate(lhs_components) *(
)) rhs.extract_component(dest, coord)
for coord, dest in enumerate(lhs_components)
)
)
class SfgNodeBuilder(ABC): class SfgNodeBuilder(ABC):
...@@ -156,6 +190,53 @@ class SfgNodeBuilder(ABC): ...@@ -156,6 +190,53 @@ class SfgNodeBuilder(ABC):
def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence:
"""Construct a sequence of C++ code from various kinds of arguments.
`make_sequence` is ubiquitous throughout the function building front-end;
among others, it powers the syntax of
[SfgComposer.function][pystencilssfg.SfgComposer.function] and
[SfgComposer.branch][pystencilssfg.SfgComposer.branch].
`make_sequence` constructs an abstract syntax tree for code within a function body, accepting various
types of arguments which then get turned into C++ code. These are:
- Strings (`str`) are printed as-is
- Tuples (`tuple`) signify *blocks*, i.e. C++ code regions enclosed in `{ }`
- Sub-ASTs and AST builders, which are often produced by the syntactic sugar and
factory methods of [SfgComposer][pystencilssfg.SfgComposer].
Its usage is best shown by example:
```Python
tree = make_sequence(
"int a = 0;",
"int b = 1;",
(
"int tmp = b;",
"b = a;",
"a = tmp;"
),
SfgKernelCall(kernel_handle)
)
sfg.context.add_function("myFunction", tree)
```
will translate to
```C++
void myFunction() {
int a = 0;
int b = 0;
{
int tmp = b;
b = a;
a = tmp;
}
kernels::kernel( ... );
}
```
"""
children = [] children = []
for i, arg in enumerate(args): for i, arg in enumerate(args):
if isinstance(arg, SfgNodeBuilder): if isinstance(arg, SfgNodeBuilder):
...@@ -186,7 +267,9 @@ class SfgBranchBuilder(SfgNodeBuilder): ...@@ -186,7 +267,9 @@ class SfgBranchBuilder(SfgNodeBuilder):
match self._phase: match self._phase:
case 0: # Condition case 0: # Condition
if len(args) != 1: if len(args) != 1:
raise ValueError("Must specify exactly one argument as branch condition!") raise ValueError(
"Must specify exactly one argument as branch condition!"
)
cond = args[0] cond = args[0]
...@@ -194,7 +277,8 @@ class SfgBranchBuilder(SfgNodeBuilder): ...@@ -194,7 +277,8 @@ class SfgBranchBuilder(SfgNodeBuilder):
cond = SfgCustomCondition(cond) cond = SfgCustomCondition(cond)
elif not isinstance(cond, SfgCondition): elif not isinstance(cond, SfgCondition):
raise ValueError( raise ValueError(
"Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.") "Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`."
)
self._cond = cond self._cond = cond
......
from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence, SfgStatements from .basic_nodes import (
SfgCallTreeNode,
SfgKernelCallNode,
SfgBlock,
SfgSequence,
SfgStatements,
SfgFunctionParams,
)
from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd
__all__ = [ __all__ = [
"SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements", "SfgCallTreeNode",
"SfgCondition", "SfgBranch", "IntEven", "IntOdd" "SfgKernelCallNode",
"SfgSequence",
"SfgBlock",
"SfgStatements",
"SfgFunctionParams",
"SfgCondition",
"SfgBranch",
"IntEven",
"IntOdd",
] ]
...@@ -13,7 +13,8 @@ if TYPE_CHECKING: ...@@ -13,7 +13,8 @@ if TYPE_CHECKING:
class SfgCallTreeNode(ABC): class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """ """Base class for all nodes comprising SFG call trees."""
def __init__(self, *children: SfgCallTreeNode): def __init__(self, *children: SfgCallTreeNode):
self._children = list(children) self._children = list(children)
...@@ -45,15 +46,15 @@ class SfgCallTreeNode(ABC): ...@@ -45,15 +46,15 @@ class SfgCallTreeNode(ABC):
@property @property
def required_includes(self) -> set[SfgHeaderInclude]: def required_includes(self) -> set[SfgHeaderInclude]:
"""Return a set of header includes required by this node"""
return set() return set()
class SfgCallTreeLeaf(SfgCallTreeNode, ABC): class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property @property
@abstractmethod @abstractmethod
def required_parameters(self) -> set[TypedSymbolOrObject]: def required_parameters(self) -> set[TypedSymbolOrObject]:
pass ...
class SfgStatements(SfgCallTreeLeaf): class SfgStatements(SfgCallTreeLeaf):
...@@ -73,10 +74,12 @@ class SfgStatements(SfgCallTreeLeaf): ...@@ -73,10 +74,12 @@ class SfgStatements(SfgCallTreeLeaf):
required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements. required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements.
""" """
def __init__(self, def __init__(
code_string: str, self,
defined_params: Sequence[TypedSymbolOrObject], code_string: str,
required_params: Sequence[TypedSymbolOrObject]): defined_params: Sequence[TypedSymbolOrObject],
required_params: Sequence[TypedSymbolOrObject],
):
super().__init__() super().__init__()
self._code_string = code_string self._code_string = code_string
...@@ -105,6 +108,28 @@ class SfgStatements(SfgCallTreeLeaf): ...@@ -105,6 +108,28 @@ class SfgStatements(SfgCallTreeLeaf):
return self._code_string return self._code_string
class SfgFunctionParams(SfgCallTreeLeaf):
def __init__(self, parameters: Sequence[TypedSymbolOrObject]):
super().__init__()
self._params = set(parameters)
self._required_includes = set()
for obj in parameters:
if isinstance(obj, SrcObject):
self._required_includes |= obj.required_includes
@property
def required_parameters(self) -> set[TypedSymbolOrObject]:
return self._params
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return self._required_includes
def get_code(self, ctx: SfgContext) -> str:
return ""
class SfgSequence(SfgCallTreeNode): class SfgSequence(SfgCallTreeNode):
def __init__(self, children: Sequence[SfgCallTreeNode]): def __init__(self, children: Sequence[SfgCallTreeNode]):
super().__init__(*children) super().__init__(*children)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment