Newer
Older
from typing import Set
from functools import reduce
from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
class FlattenSequences:
    """Flattens any nested sequences occurring in a kernel call tree.

    The tree is modified in place: afterwards every ``SfgSequence`` holds only
    non-sequence children, with the contents of each nested sequence spliced
    in at the position the nested sequence originally occupied.
    """

    def visit(self, node: SfgCallTreeNode) -> None:
        """Flatten all sequences reachable from ``node``."""
        if not isinstance(node, SfgSequence):
            # Non-sequence nodes are left as they are; only their subtrees change.
            for child in node.children:
                self.visit(child)
            return None
        return self._visit_SfgSequence(node)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
        """Replace ``sequence``'s children with its flattened, non-sequence descendants."""

        def _leaves(seq: SfgSequence):
            # Depth-first, left to right, so the original execution order is kept.
            for child in seq.children:
                if isinstance(child, SfgSequence):
                    yield from _leaves(child)
                else:
                    yield child

        flat = list(_leaves(sequence))

        # Recurse into the surviving children so that sequences buried deeper
        # (e.g. below a branching node) get flattened as well.
        for child in flat:
            self.visit(child)

        sequence._children = flat
"""Collects all parameters required but not defined in a kernel call tree.
Requires that all sequences in the tree are flattened.
"""
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
    """Return the symbols that ``node``'s subtree requires from outside.

    Dispatch order matters: leaves are tested first, then sequences; every
    other node is treated as a branching node.
    """
    if isinstance(node, SfgCallTreeLeaf):
        return self._visit_SfgCallTreeLeaf(node)
    if isinstance(node, SfgSequence):
        return self._visit_SfgSequence(node)
    return self._visit_branchingNode(node)
def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
    """Return the symbols required by ``leaf``.

    A fresh set is returned rather than ``leaf.required_symbols`` itself, so a
    caller that mutates the result (e.g. with ``-=`` or ``|=``) cannot alter
    the leaf's own record of its requirements.
    """
    return set(leaf.required_symbols)
def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
    """Collect the symbols a flattened sequence requires from outside.

    Only in a sequence may parameters be defined and visible to subsequent
    nodes. The children are therefore walked back to front. Symbols defined
    by a statement are removed from the set required by the nodes after it,
    and then that statement's own requirements are added. A statement that
    both defines and uses a symbol still reports it as required.

    Raises:
        AssertionError: if a child is itself a sequence, i.e. the tree was
            not flattened beforehand.
    """
    params = set()
    for c in sequence.children[::-1]:
        # The flattening precondition applies to every child, not just statements.
        assert not isinstance(c, SfgSequence), "Sequence not flattened."
        if isinstance(c, SfgStatements):
            # NOTE(review): assumes SfgStatements exposes `defined_symbols`
            # (counterpart of `required_symbols`) - confirm in basic_nodes.
            # difference_update accepts any iterable, not only sets.
            params.difference_update(c.defined_symbols)
        params |= self.visit(c)
    return params
def _visit_branchingNode(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
    """
    Each interior node that is not a sequence simply requires the union of all parameters
    required by its children.

    A node without children requires nothing (empty set). The union is built in
    one pass instead of re-copying an accumulator for every child.
    """
    return set().union(*(self.visit(c) for c in node.children))