Skip to content
Snippets Groups Projects
visitors.py 4.48 KiB
Newer Older
Frederik Hennig's avatar
Frederik Hennig committed
from typing import TYPE_CHECKING
Frederik Hennig's avatar
Frederik Hennig committed
from functools import reduce

from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
from .deferred_nodes import SfgParamCollectionDeferredNode
from .dispatcher import visitor
Frederik Hennig's avatar
Frederik Hennig committed
from ..source_concepts.source_objects import TypedSymbolOrObject
if TYPE_CHECKING:
    from ..context import SfgContext


class FlattenSequences:
    """Flattens any nested sequences occurring in a kernel call tree.

    After visiting, no `SfgSequence` reached by the traversal has another
    `SfgSequence` as a direct child: the children of nested sequences are
    spliced into their parent, preserving their order.
    """

    @visitor
    def visit(self, node: SfgCallTreeNode) -> None:
        """Generic case: recurse into every child of `node`."""
        for child in node.children:
            self.visit(child)

    @visit.case(SfgSequence)
    def sequence(self, sequence: SfgSequence) -> None:
        """Replace the children of `sequence` with its flattened, non-sequence descendants."""

        def non_sequence_descendants(seq: SfgSequence):
            # Depth-first, in order: nested sequences are dissolved into their elements.
            for child in seq.children:
                if isinstance(child, SfgSequence):
                    yield from non_sequence_descendants(child)
                else:
                    yield child

        # Materialize the flat list first, so that all flattening on this level
        # is finished before any of the children is visited.
        flat_children = list(non_sequence_descendants(sequence))

        # The remaining children may themselves contain sequences further down.
        for child in flat_children:
            self.visit(child)

        sequence._children = flat_children
class CollectIncludes:
    """Collects the includes required by a call tree node and all of its descendants."""

    def visit(self, node: SfgCallTreeNode):
        """Return the union of `required_includes` over `node` and its subtree.

        Args:
            node: Root of the (sub-)tree to inspect. Must provide
                `required_includes` and `children`.

        Returns:
            A new set holding the includes of `node` and of all nodes below it.
        """
        # Work on a copy: `|=` updates a set in place, so accumulating directly
        # into `node.required_includes` would leak the children's includes into
        # the node's own collection.
        includes = set(node.required_includes)
        for child in node.children:
            includes |= self.visit(child)

        return includes


class ExpandingParameterCollector:
    """Collects all parameters required but not defined in a kernel call tree.

    Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found
    within sequences on the way.
    """

    # The annotation is quoted because `SfgContext` is only imported under
    # `TYPE_CHECKING` and therefore does not exist as a name at runtime.
    def __init__(self, ctx: "SfgContext") -> None:
        self._ctx = ctx

        self._flattener = FlattenSequences()

    @visitor
    def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
        """Generic case: treat `node` as a branching node."""
        return self.branching_node(node)

    @visit.case(SfgCallTreeLeaf)
    def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
        """A leaf requires exactly its own `required_parameters`."""
        return leaf.required_parameters

    @visit.case(SfgSequence)
    def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
        """
            Only in a sequence may parameters be defined and visible to subsequent nodes.

            The children are processed back to front. Deferred parameter-collection
            nodes are expanded using the parameters collected so far and replaced
            in the sequence by their expansion.
        """
        params: set[TypedSymbolOrObject] = set()

        def iter_nested_sequences(seq: SfgSequence, visible_params: set[TypedSymbolOrObject]):
            # Iterate by index, last to first, since children may be replaced in place.
            for i in range(len(seq.children) - 1, -1, -1):
                c = seq.children[i]

                if isinstance(c, SfgParamCollectionDeferredNode):
                    # Expand the deferred node and store the result in its place.
                    c = c.expand(self._ctx, visible_params=visible_params)
                    seq[i] = c

                if isinstance(c, SfgSequence):
                    # Nested sequences share the same parameter set as their parent.
                    iter_nested_sequences(c, visible_params)
                else:
                    if isinstance(c, SfgStatements):
                        # Parameters defined here are no longer required from outside ...
                        visible_params -= c.defined_parameters
                    # ... but whatever this child requires itself is.
                    visible_params |= self.visit(c)

        iter_nested_sequences(sequence, params)

        return params

    def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
        """
            Each interior node that is not a sequence simply requires the union of all parameters
            required by its children.
        """
        return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())


Frederik Hennig's avatar
Frederik Hennig committed
class ParameterCollector:
    """Collects all parameters required but not defined in a kernel call tree.

    Requires that all sequences in the tree are flattened.
    """

    @visitor
    def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
        """Generic case: treat `node` as a branching node."""
        return self.branching_node(node)

    @visit.case(SfgCallTreeLeaf)
    def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
        """A leaf requires exactly its own `required_parameters`."""
        return leaf.required_parameters

    @visit.case(SfgSequence)
    def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
        """
            Only in a sequence may parameters be defined and visible to subsequent nodes.

            The children are processed back to front: parameters defined by a
            statements node are removed from the set collected from the nodes
            after it, then that node's own requirements are added.
        """
        params: set[TypedSymbolOrObject] = set()

        for c in sequence.children[::-1]:
            # Check the precondition before touching the collected set.
            assert not isinstance(c, SfgSequence), "Sequence not flattened."

            if isinstance(c, SfgStatements):
                params -= c.defined_parameters

            params |= self.visit(c)

        return params

    def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
        """
            Each interior node that is not a sequence simply requires the union of all parameters
            required by its children.
        """
        return reduce(lambda x, y: x | y, (self.visit(c) for c in node.children), set())