Skip to content
Snippets Groups Projects
Commit cb4a449d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

added switch-case

parent ed11c419
No related merge requests found
Pipeline #58322 passed with stages
in 4 minutes and 9 seconds
......@@ -16,7 +16,7 @@ from .tree import (
SfgBlock,
)
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch
from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch, SfgSwitch
from .source_components import (
SfgFunction,
SfgHeaderInclude,
......@@ -82,8 +82,8 @@ class SfgComposer:
return kns
def include(self, header_file: str):
self._ctx.add_include(parse_include(header_file))
def include(self, header_file: str, private: bool = False):
self._ctx.add_include(parse_include(header_file, private))
def numpy_struct(
self, name: str, dtype: np.dtype, add_constructor: bool = True
......@@ -154,7 +154,7 @@ class SfgComposer:
"""
return SfgKernelCallNode(kernel_handle)
def seq(self, *args: SfgCallTreeNode) -> SfgSequence:
def seq(self, *args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence:
"""Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]"""
return make_sequence(*args)
......@@ -180,6 +180,9 @@ class SfgComposer:
"""
return SfgBranchBuilder()
def switch(self, switch_arg: str | TypedSymbolOrObject) -> SfgSwitchBuilder:
return SfgSwitchBuilder(switch_arg)
def map_field(self, field: Field, src_object: SrcField) -> SfgDeferredFieldMapping:
"""Map a pystencils field to a field data structure, from which pointers, sizes
and strides should be extracted.
......@@ -322,7 +325,37 @@ class SfgBranchBuilder(SfgNodeBuilder):
return SfgBranch(self._cond, self._branch_true, self._branch_false)
def parse_include(incl: str | SfgHeaderInclude):
class SfgSwitchBuilder(SfgNodeBuilder):
def __init__(self, switch_arg: str | TypedSymbolOrObject):
self._switch_arg = switch_arg
self._cases: dict[str, SfgCallTreeNode] = dict()
self._default: SfgCallTreeNode | None = None
def case(self, label: str):
if label in self._cases:
raise SfgException(f"Duplicate case: {label}")
def sequencer(*args):
tree = make_sequence(*args)
self._cases[label] = tree
return self
return sequencer
def default(self, *args):
if self._default is not None:
raise SfgException("Duplicate default case")
tree = make_sequence(*args)
self._default = tree
return self
def resolve(self) -> SfgCallTreeNode:
return SfgSwitch(self._switch_arg, self._cases, self._default)
def parse_include(incl: str | SfgHeaderInclude, private: bool = False):
if isinstance(incl, SfgHeaderInclude):
return incl
......@@ -331,7 +364,7 @@ def parse_include(incl: str | SfgHeaderInclude):
incl = incl[1:-1]
system_header = True
return SfgHeaderInclude(incl, system_header=system_header)
return SfgHeaderInclude(incl, system_header=system_header, private=private)
class SfgClassComposer:
......@@ -347,10 +380,9 @@ class SfgClassComposer:
def __call__(
self,
*args: SfgClassMember
| SfgClassComposer.ConstructorBuilder
| SrcObject
| str,
*args: (
SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str
),
):
for arg in args:
self._vis_block.append_member(SfgClassComposer._resolve_member(arg))
......@@ -430,11 +462,13 @@ class SfgClassComposer:
self._ctx.add_class(cls)
def sequencer(
*args: SfgClassComposer.VisibilityContext
| SfgClassMember
| SfgClassComposer.ConstructorBuilder
| SrcObject
| str,
*args: (
SfgClassComposer.VisibilityContext
| SfgClassMember
| SfgClassComposer.ConstructorBuilder
| SrcObject
| str
),
):
default_ended = False
......@@ -465,7 +499,7 @@ class SfgClassComposer:
@staticmethod
def _resolve_member(
arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str,
arg: (SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str),
):
if isinstance(arg, SrcObject):
return SfgMemberVariable(arg.name, arg.dtype)
......
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, cast, Generator
from pystencils.typing import TypedSymbol, BasicType
......@@ -60,10 +60,12 @@ class IntOdd(SfgCondition):
class SfgBranch(SfgCallTreeNode):
def __init__(self,
cond: SfgCondition,
branch_true: SfgCallTreeNode,
branch_false: Optional[SfgCallTreeNode] = None):
def __init__(
self,
cond: SfgCondition,
branch_true: SfgCallTreeNode,
branch_false: Optional[SfgCallTreeNode] = None,
):
super().__init__(cond, branch_true, *((branch_false,) if branch_false else ()))
@property
......@@ -89,3 +91,45 @@ class SfgBranch(SfgCallTreeNode):
code += "\n}"
return code
class SfgSwitch(SfgCallTreeNode):
def __init__(
self,
switch_arg: str | TypedSymbolOrObject,
cases_dict: dict[str, SfgCallTreeNode],
default: SfgCallTreeNode | None = None,
):
children = tuple(cases_dict.values()) + (
(default,) if default is not None else ()
)
super().__init__(*children)
self._switch_arg = switch_arg
self._cases_dict = cases_dict
self._default = default
@property
def switch_arg(self) -> str | TypedSymbolOrObject:
return self._switch_arg
def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]:
yield from self._cases_dict.items()
@property
def default(self) -> SfgCallTreeNode | None:
return self._default
def get_code(self, ctx: SfgContext) -> str:
code = f"switch({self._switch_arg}) {{\n"
for label, subtree in self._cases_dict.items():
code += f"case {label}: {{\n"
code += ctx.codestyle.indent(subtree.get_code(ctx))
code += "\nbreak;\n}\n"
if self._default is not None:
code += "default: {\n"
code += ctx.codestyle.indent(self._default.get_code(ctx))
code += "\nbreak;\n}\n"
code += "}"
return code
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment