Newer
Older
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING:
from ..context import SfgContext
from functools import reduce
from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
from .deferred_nodes import SfgParamCollectionDeferredNode
class FlattenSequences():
    """
    Flattens nested sequences occurring anywhere in a kernel call tree.

    Every `SfgSequence` in the tree has the contents of its nested `SfgSequence`
    children spliced into it, in place. Afterwards no sequence directly contains
    another sequence.
    """

    def visit(self, node: SfgCallTreeNode) -> None:
        """Flatten all sequences in the subtree rooted at ``node``, in place."""
        if not isinstance(node, SfgSequence):
            # Not a sequence itself: sequences may still be hidden further down.
            for child in node.children:
                self.visit(child)
            return None
        return self._visit_SfgSequence(node)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
        """Replace ``sequence``'s children by its non-sequence descendants, then recurse into them."""

        def non_sequence_nodes(seq: SfgSequence):
            # Depth-first and left-to-right, so the original execution order is kept.
            for child in seq.children:
                if isinstance(child, SfgSequence):
                    yield from non_sequence_nodes(child)
                else:
                    yield child

        flat = list(non_sequence_nodes(sequence))

        # The kept children may themselves contain sequences in their subtrees.
        for child in flat:
            self.visit(child)

        sequence._children = flat
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
class CollectIncludes:
    """Collects the ``required_includes`` of every node in a call tree."""

    def visit(self, node: SfgCallTreeNode):
        """
        Return the union of ``required_includes`` over ``node`` and all of its descendants.

        A fresh set is built for the result. This way the in-place union below never
        modifies the ``required_includes`` collection held by a visited node.
        """
        includes = set(node.required_includes)
        for c in node.children:
            includes |= self.visit(c)
        return includes
class ExpandingParameterCollector():
    """Collects all parameters required but not defined in a kernel call tree.
    Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way.
    """

    def __init__(self, ctx: SfgContext) -> None:
        """
        Args:
            ctx: Context passed to ``SfgParamCollectionDeferredNode.expand`` when deferred nodes are expanded.
        """
        self._ctx = ctx
        # NOTE(review): not used by any method in this class; kept so the attribute stays available.
        self._flattener = FlattenSequences()

    def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
        """
        Return the symbols required, but not defined, by the subtree rooted at ``node``.

        Deferred parameter-collection nodes found inside sequences are expanded and
        replaced in the tree as a side effect.
        """
        if isinstance(node, SfgCallTreeLeaf):
            return self._visit_SfgCallTreeLeaf(node)
        elif isinstance(node, SfgSequence):
            return self._visit_SfgSequence(node)
        else:
            return self._visit_branchingNode(node)

    def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
        """
        A leaf requires exactly its own ``required_symbols``.

        A copy is returned, so callers that update the result cannot modify the leaf.
        """
        return set(leaf.required_symbols)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
        """
        Only in a sequence may parameters be defined and visible to subsequent nodes.

        Children are walked back to front. The accumulated set holds the symbols
        required by everything after the current child. A statement that defines
        some of them satisfies those requirements, so they are removed before that
        statement's own requirements are added. Nested sequences share the same set,
        which treats them as if they were flattened into the enclosing sequence.
        """
        params: Set[TypedSymbol] = set()

        def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]):
            for i in reversed(range(len(seq.children))):
                c = seq.children[i]

                if isinstance(c, SfgParamCollectionDeferredNode):
                    # The deferred node receives the live set collected so far from the
                    # nodes after it. Its expansion replaces it in the tree and is then
                    # processed like any other child.
                    c = c.expand(self._ctx, visible_params=visible_params)
                    seq.replace_child(i, c)

                if isinstance(c, SfgSequence):
                    iter_nested_sequences(c, visible_params)
                else:
                    if isinstance(c, SfgStatements):
                        visible_params -= c.defined_symbols
                    visible_params |= self.visit(c)

        iter_nested_sequences(sequence, params)
        return params

    def _visit_branchingNode(self, node: SfgCallTreeNode):
        """
        Each interior node that is not a sequence simply requires the union of all parameters
        required by its children.
        """
        return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())
class ParameterCollector():
    """Collects all parameters required but not defined in a kernel call tree.
    Requires that all sequences in the tree are flattened.
    """

    def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
        """Return the symbols required, but not defined, by the subtree rooted at ``node``."""
        if isinstance(node, SfgCallTreeLeaf):
            return self._visit_SfgCallTreeLeaf(node)
        elif isinstance(node, SfgSequence):
            return self._visit_SfgSequence(node)
        else:
            return self._visit_branchingNode(node)

    def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
        """
        A leaf requires exactly its own ``required_symbols``.

        A copy is returned, so callers that update the result cannot modify the leaf.
        """
        return set(leaf.required_symbols)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
        """
        Only in a sequence may parameters be defined and visible to subsequent nodes.

        Children are walked back to front. ``params`` holds the symbols required by
        everything after the current child. A statement that defines some of them
        satisfies those requirements, so they are removed before that statement's
        own requirements are added.
        """
        params: Set[TypedSymbol] = set()
        for c in sequence.children[::-1]:
            # Checked for every child: this collector does not descend into nested sequences.
            assert not isinstance(c, SfgSequence), "Sequence not flattened."
            if isinstance(c, SfgStatements):
                params -= c.defined_symbols
            params |= self.visit(c)
        return params

    def _visit_branchingNode(self, node: SfgCallTreeNode):
        """
        Each interior node that is not a sequence simply requires the union of all parameters
        required by its children.
        """
        return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())