Skip to content
Snippets Groups Projects
Commit 3edfb527 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

added function params + documentation for sequencing

parent 4e5c7b3a
Branches
Tags
No related merge requests found
Pipeline #57802 passed with stages
in 43 seconds
::: pystencilssfg.context.SfgContext
::: pystencilssfg.composer.SfgComposer
::: pystencilssfg.source_components.SfgKernelNamespace
::: pystencilssfg.source_components.SfgKernelHandle
::: pystencilssfg.source_components.SfgFunction
::: pystencilssfg.composer
::: pystencilssfg.source_components
......@@ -21,6 +21,8 @@ plugins:
paths: [src]
options:
heading_level: 2
members_order: source
group_by_category: False
show_root_heading: True
show_root_full_path: False
show_symbol_type_heading: True
......@@ -44,5 +46,6 @@ nav:
- 'API Documentation':
- 'Overview': api/index.md
- 'Source File Generator': api/generator.md
- 'Composer and Source File Components': api/composition.md
- 'Composer': api/composition.md
- 'Source File Components': api/source_components.md
- 'Kernel Call Tree': api/tree.md
......@@ -5,10 +5,22 @@ from abc import ABC, abstractmethod
from pystencils import Field
from pystencils.astnodes import KernelFunction
from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements, SfgSequence, SfgBlock
from .tree import (
SfgCallTreeNode,
SfgKernelCallNode,
SfgStatements,
SfgFunctionParams,
SfgSequence,
SfgBlock,
)
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch
from .source_components import SfgFunction, SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle
from .source_components import (
SfgFunction,
SfgHeaderInclude,
SfgKernelNamespace,
SfgKernelHandle,
)
from .source_concepts import SrcField, TypedSymbolOrObject, SrcVector
if TYPE_CHECKING:
......@@ -63,9 +75,18 @@ class SfgComposer:
header_file = header_file[1:-1]
system_header = True
self._ctx.add_include(SfgHeaderInclude(header_file, system_header=system_header))
self._ctx.add_include(
SfgHeaderInclude(header_file, system_header=system_header)
)
def kernel_function(
self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle
):
"""Creates a function comprising just a single kernel call.
def kernel_function(self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle):
Args:
ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST.
"""
if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.")
......@@ -90,6 +111,9 @@ class SfgComposer:
# Function Body
)
```
The function body is constructed via sequencing;
refer to [make_sequence][pystencilssfg.composer.make_sequence].
"""
if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.")
......@@ -110,8 +134,13 @@ class SfgComposer:
return SfgKernelCallNode(kernel_handle)
def seq(self, *args: SfgCallTreeNode) -> SfgSequence:
"""Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]"""
return make_sequence(*args)
def params(self, *args: TypedSymbolOrObject) -> SfgFunctionParams:
"""Use inside a function body to add parameters to the function."""
return SfgFunctionParams(args)
@property
def branch(self) -> SfgBranchBuilder:
"""Use inside a function body to create an if/else conditonal branch.
......@@ -137,16 +166,21 @@ class SfgComposer:
"""
return SfgDeferredFieldMapping(field, src_object)
def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
def map_param(
self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str
):
"""Arbitrary parameter mapping: Add a single line of code to define a left-hand
side object from a right-hand side."""
return SfgStatements(mapping, (lhs,), (rhs,))
def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector):
"""Extracts scalar numerical values from a vector data type."""
return make_sequence(*(
rhs.extract_component(dest, coord) for coord, dest in enumerate(lhs_components)
))
return make_sequence(
*(
rhs.extract_component(dest, coord)
for coord, dest in enumerate(lhs_components)
)
)
class SfgNodeBuilder(ABC):
......@@ -156,6 +190,53 @@ class SfgNodeBuilder(ABC):
def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence:
"""Construct a sequence of C++ code from various kinds of arguments.
`make_sequence` is ubiquitous throughout the function building front-end;
among others, it powers the syntax of
[SfgComposer.function][pystencilssfg.SfgComposer.function] and
[SfgComposer.branch][pystencilssfg.SfgComposer.branch].
`make_sequence` constructs an abstract syntax tree for code within a function body, accepting various
types of arguments which then get turned into C++ code. These are:
- Strings (`str`) are printed as-is
- Tuples (`tuple`) signify *blocks*, i.e. C++ code regions enclosed in `{ }`
- Sub-ASTs and AST builders, which are often produced by the syntactic sugar and
factory methods of [SfgComposer][pystencilssfg.SfgComposer].
Its usage is best shown by example:
```Python
tree = make_sequence(
"int a = 0;",
"int b = 1;",
(
"int tmp = b;",
"b = a;",
"a = tmp;"
),
SfgKernelCall(kernel_handle)
)
sfg.context.add_function("myFunction", tree)
```
will translate to
```C++
void myFunction() {
int a = 0;
int b = 0;
{
int tmp = b;
b = a;
a = tmp;
}
kernels::kernel( ... );
}
```
"""
children = []
for i, arg in enumerate(args):
if isinstance(arg, SfgNodeBuilder):
......@@ -186,7 +267,9 @@ class SfgBranchBuilder(SfgNodeBuilder):
match self._phase:
case 0: # Condition
if len(args) != 1:
raise ValueError("Must specify exactly one argument as branch condition!")
raise ValueError(
"Must specify exactly one argument as branch condition!"
)
cond = args[0]
......@@ -194,7 +277,8 @@ class SfgBranchBuilder(SfgNodeBuilder):
cond = SfgCustomCondition(cond)
elif not isinstance(cond, SfgCondition):
raise ValueError(
"Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`.")
"Invalid type for branch condition. Must be either `str` or a subclass of `SfgCondition`."
)
self._cond = cond
......
from .basic_nodes import SfgCallTreeNode, SfgKernelCallNode, SfgBlock, SfgSequence, SfgStatements
from .basic_nodes import (
SfgCallTreeNode,
SfgKernelCallNode,
SfgBlock,
SfgSequence,
SfgStatements,
SfgFunctionParams,
)
from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd
__all__ = [
"SfgCallTreeNode", "SfgKernelCallNode", "SfgSequence", "SfgBlock", "SfgStatements",
"SfgCondition", "SfgBranch", "IntEven", "IntOdd"
"SfgCallTreeNode",
"SfgKernelCallNode",
"SfgSequence",
"SfgBlock",
"SfgStatements",
"SfgFunctionParams",
"SfgCondition",
"SfgBranch",
"IntEven",
"IntOdd",
]
......@@ -13,7 +13,8 @@ if TYPE_CHECKING:
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """
"""Base class for all nodes comprising SFG call trees."""
def __init__(self, *children: SfgCallTreeNode):
self._children = list(children)
......@@ -45,15 +46,15 @@ class SfgCallTreeNode(ABC):
@property
def required_includes(self) -> set[SfgHeaderInclude]:
"""Return a set of header includes required by this node"""
return set()
class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property
@abstractmethod
def required_parameters(self) -> set[TypedSymbolOrObject]:
pass
...
class SfgStatements(SfgCallTreeLeaf):
......@@ -73,10 +74,12 @@ class SfgStatements(SfgCallTreeLeaf):
required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements.
"""
def __init__(self,
code_string: str,
defined_params: Sequence[TypedSymbolOrObject],
required_params: Sequence[TypedSymbolOrObject]):
def __init__(
self,
code_string: str,
defined_params: Sequence[TypedSymbolOrObject],
required_params: Sequence[TypedSymbolOrObject],
):
super().__init__()
self._code_string = code_string
......@@ -105,6 +108,28 @@ class SfgStatements(SfgCallTreeLeaf):
return self._code_string
class SfgFunctionParams(SfgCallTreeLeaf):
def __init__(self, parameters: Sequence[TypedSymbolOrObject]):
super().__init__()
self._params = set(parameters)
self._required_includes = set()
for obj in parameters:
if isinstance(obj, SrcObject):
self._required_includes |= obj.required_includes
@property
def required_parameters(self) -> set[TypedSymbolOrObject]:
return self._params
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return self._required_includes
def get_code(self, ctx: SfgContext) -> str:
return ""
class SfgSequence(SfgCallTreeNode):
def __init__(self, children: Sequence[SfgCallTreeNode]):
super().__init__(*children)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment