from typing import Set
from functools import reduce

from pystencils.typing import TypedSymbol

from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements


class FlattenSequences:
    """Flattens any nested sequences occurring in a kernel call tree.

    After visiting, no `SfgSequence` in the tree has another `SfgSequence`
    as a direct child: nested sequences are spliced into their parent,
    preserving the left-to-right order of their children. The tree is
    modified in place.
    """

    def visit(self, node: SfgCallTreeNode) -> None:
        """Flatten all sequences in the subtree rooted at ``node``, in place."""
        if isinstance(node, SfgSequence):
            return self._visit_SfgSequence(node)
        else:
            for c in node.children:
                self.visit(c)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
        """Splice all (transitively) nested sequences into ``sequence``."""
        children_flattened = []

        def flatten(seq: SfgSequence):
            # Depth-first, left-to-right walk, so the relative order of the
            # non-sequence children is the same as in the nested structure.
            for c in seq.children:
                if isinstance(c, SfgSequence):
                    flatten(c)
                else:
                    children_flattened.append(c)

        flatten(sequence)

        # The collected children are not sequences themselves, but may
        # contain sequences deeper in their own subtrees.
        for c in children_flattened:
            self.visit(c)

        # NOTE(review): assigns the private ``_children`` attribute directly;
        # presumably SfgSequence offers no public setter — confirm.
        sequence._children = children_flattened


class ParameterCollector:
    """Collects all parameters required but not defined in a kernel call tree.

    Requires that all sequences in the tree are flattened
    (see `FlattenSequences`).
    """

    def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
        """Return the set of symbols required by ``node`` but not defined within it.

        The returned set is always a fresh object owned by the caller.
        """
        if isinstance(node, SfgCallTreeLeaf):
            return self._visit_SfgCallTreeLeaf(node)
        elif isinstance(node, SfgSequence):
            return self._visit_SfgSequence(node)
        else:
            return self._visit_branchingNode(node)

    def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
        # Copy, so that a caller mutating the result cannot affect whatever
        # set object the leaf exposes via ``required_symbols``.
        return set(leaf.required_symbols)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
        """
        Only in a sequence may parameters be defined and visible to subsequent nodes.

        Children are walked back-to-front. When an `SfgStatements` node is
        reached, the symbols it defines are removed from the requirements
        collected from the nodes after it. Its own required symbols are then
        added.
        """
        params = set()
        for c in sequence.children[::-1]:
            # Check the invariant before the child is used at all.
            assert not isinstance(c, SfgSequence), "Sequence not flattened."
            if isinstance(c, SfgStatements):
                params -= c.defined_symbols
            params |= self.visit(c)
        return params

    def _visit_branchingNode(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
        """
        Each interior node that is not a sequence simply requires the union of all
        parameters required by its children.
        """
        return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())