# File: visitors.py (2.36 KiB)
# Committed by: Frederik Hennig
# NOTE: web-page navigation text captured with this file ("Skip to content",
# "Snippets Groups Projects", "Newer Older") is not part of the module.
from typing import Set
from functools import reduce

from pystencils.typing import TypedSymbol

from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements


class FlattenSequences:
    """Visitor that collapses nested sequences in a kernel call tree.

    Visiting an `SfgSequence` replaces its child list with a flat list in which
    every directly nested `SfgSequence` has been substituted by its own
    (recursively flattened) children, preserving their order.
    """

    def visit(self, node: SfgCallTreeNode) -> None:
        """Flatten every sequence reachable from ``node``.

        Sequences are flattened in place; any other node is left untouched and
        the traversal simply descends into its children.
        """
        if not isinstance(node, SfgSequence):
            for child in node.children:
                self.visit(child)
            return None

        return self._visit_SfgSequence(node)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
        """Flatten ``sequence`` in place, then continue into its new children."""
        flat = []

        # Depth-first, left-to-right walk through nested sequences using a stack
        # of child iterators. When a nested sequence is met, its iterator is
        # pushed and consumed first; the parent iterator resumes afterwards
        # exactly where it stopped, so the original order is preserved.
        open_iters = [iter(sequence.children)]
        while open_iters:
            for item in open_iters[-1]:
                if isinstance(item, SfgSequence):
                    open_iters.append(iter(item.children))
                    break
                flat.append(item)
            else:
                # Innermost iterator is exhausted; return to its parent.
                open_iters.pop()

        # None of the collected nodes is a sequence, but each may contain
        # sequences further down the tree.
        for item in flat:
            self.visit(item)

        sequence._children = flat
# (blame-view annotation: Frederik Hennig committed)


class ParameterCollector:
    """Collects all parameters required but not defined in a kernel call tree.

    Requires that all sequences in the tree are flattened.
    """

    def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
        """Return the set of symbols that ``node`` requires from its context.

        Dispatches on the node kind: leaves, sequences, and every other
        (branching) interior node are each handled by a dedicated method.
        """
        if isinstance(node, SfgCallTreeLeaf):
            return self._visit_SfgCallTreeLeaf(node)
        if isinstance(node, SfgSequence):
            return self._visit_SfgSequence(node)
        return self._visit_branchingNode(node)

    def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
        """Return the symbols required by a leaf node.

        A fresh set is returned so that callers mutating the result cannot
        alter the leaf's own ``required_symbols``.
        """
        return set(leaf.required_symbols)

    def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
        """Return the symbols required by a sequence but not defined inside it.

        Only in a sequence may parameters be defined and visible to subsequent
        nodes.

        Raises:
            AssertionError: If a direct child is itself a sequence, i.e. the
                tree has not been flattened.
        """
        params: Set[TypedSymbol] = set()

        # Walk the children back to front: at each step ``params`` holds the
        # symbols required by all nodes that come *after* the current child.
        for c in reversed(sequence.children):
            assert not isinstance(c, SfgSequence), "Sequence not flattened."

            if isinstance(c, SfgStatements):
                # Symbols defined here satisfy the requirements of later nodes.
                params -= c.defined_symbols

            # The child's own requirements are added after the subtraction, so a
            # symbol it both requires and defines stays required from outside.
            params |= self.visit(c)

        return params

    def _visit_branchingNode(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
        """Return the union of the symbols required by all children of ``node``.

        Each interior node that is not a sequence simply requires the union of
        all parameters required by its children.
        """
        return set().union(*(self.visit(c) for c in node.children))