Newer
Older
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING:
from ..context import SfgContext
from abc import ABC, abstractmethod
from functools import reduce
from jinja2.filters import do_indent
from ..kernel_namespace import SfgKernelHandle
from ..source_concepts.source_concepts import SrcObject
from pystencils.typing import TypedSymbol
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """
@property
@abstractmethod
def children(self) -> Sequence[SfgCallTreeNode]:
pass
@abstractmethod
def get_code(self, ctx: SfgContext) -> str:
"""Returns the code of this node.
By convention, the code block emitted by this function should not contain a trailing newline.
"""
pass
class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return ()
@property
@abstractmethod
def required_symbols(self) -> Set[TypedSymbol]:
class SfgStatements(SfgCallTreeLeaf):
"""Represents (a sequence of) statements in the source language.
This class groups together arbitrary code strings
(e.g. sequences of C++ statements, cf. https://en.cppreference.com/w/cpp/language/statements),
and annotates them with the set of symbols read and written by these statements.
It is the user's responsibility to ensure that the code string is valid code in the output language,
and that the lists of required and defined objects are correct and complete.
Args:
code_string: Code to be printed out.
defined_objects: Objects (as `SrcObject` or `TypedSymbol`) that will be newly defined and visible to
code in sequence after these statements.
required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements.
"""
def __init__(self,
code_string: str,
defined_objects: Sequence[Union[SrcObject, TypedSymbol]],
required_objects: Sequence[Union[SrcObject, TypedSymbol]]):
self._code_string = code_string
def to_symbol(obj: Union[SrcObject, TypedSymbol]):
if isinstance(obj, SrcObject):
self._required_symbols.add(obj.typed_symbol)
elif isinstance(obj, TypedSymbol):
self._required_symbols.add(obj)
else:
raise ValueError(f"Required object in expression is neither TypedSymbol nor SrcObject: {obj}")
self._defined_symbols = set(map(to_symbol, defined_objects))
self._required_symbols = set(map(to_symbol, required_objects))
def required_symbols(self) -> Set[TypedSymbol]:
return self._required_symbols
@property
def defined_symbols(self) -> Set[TypedSymbol]:
return self._defined_symbols
return self._code_string
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class SfgSequence(SfgCallTreeNode):
def __init__(self, children: Sequence[SfgCallTreeNode]):
self._children = tuple(children)
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return self._children
def get_code(self, ctx: SfgContext) -> str:
return "\n".join(c.get_code(ctx) for c in self._children)
class SfgBlock(SfgCallTreeNode):
def __init__(self, subtree: SfgCallTreeNode):
super().__init__(ctx)
self._subtree = subtree
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return { self._subtree }
def get_code(self, ctx: SfgContext) -> str:
subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx))
return "{\n" + subtree_code + "\n}"
class SfgKernelCallNode(SfgCallTreeLeaf):
def __init__(self, kernel_handle: SfgKernelHandle):
self._kernel_handle = kernel_handle
@property
def required_symbols(self) -> Set[TypedSymbol]:
return set(p.symbol for p in self._kernel_handle.parameters)
def get_code(self, ctx: SfgContext) -> str:
ast_params = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fully_qualified_name
call_parameters = ", ".join([p.symbol.name for p in ast_params])
return f"{fnc_name}({call_parameters});"