Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from functools import reduce
from jinja2.filters import do_indent
from ..kernel_namespace import SfgKernelHandle
from pystencils.typing import TypedSymbol
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """
@property
@abstractmethod
def children(self) -> Sequence[SfgCallTreeNode]:
pass
@abstractmethod
def get_code(self, ctx: SfgContext) -> str:
"""Returns the code of this node.
By convention, the code block emitted by this function should not contain a trailing newline.
"""
pass
class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return ()
@property
@abstractmethod
def required_symbols(self) -> set(TypedSymbol):
pass
class SfgParameterDefinition(SfgCallTreeLeaf):
def __init__(self, defined_param: TypedSymbol, required_params: Set[TypedSymbol], code_string: str):
self._defined_param = defined_param
self._required_params = required_params
self._code_string = code_string
def defined_symbol(self) -> TypedSymbol:
return self._defined_param
@property
def required_symbols(self) -> set(TypedSymbol):
return self._required_params
def get_code(self):
return self._code_string
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class SfgCustomStatement(SfgCallTreeLeaf):
def __init__(self, statement: str):
self._statement = statement
def required_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str:
return self._statement
class SfgSequence(SfgCallTreeNode):
def __init__(self, children: Sequence[SfgCallTreeNode]):
self._children = tuple(children)
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return self._children
def get_code(self, ctx: SfgContext) -> str:
return "\n".join(c.get_code(ctx) for c in self._children)
class SfgBlock(SfgCallTreeNode):
def __init__(self, subtree: SfgCallTreeNode):
super().__init__(ctx)
self._subtree = subtree
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return { self._subtree }
def get_code(self, ctx: SfgContext) -> str:
subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx))
return "{\n" + subtree_code + "\n}"
class SfgKernelCallNode(SfgCallTreeLeaf):
def __init__(self, kernel_handle: SfgKernelHandle):
self._kernel_handle = kernel_handle
@property
def required_symbols(self) -> set(TypedSymbol):
return set(p.symbol for p in self._kernel_handle.parameters)
def get_code(self, ctx: SfgContext) -> str:
ast_params = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fully_qualified_name
call_parameters = ", ".join([p.symbol.name for p in ast_params])
return f"{fnc_name}({call_parameters});"