Skip to content
Snippets Groups Projects
dispatcher.py 1.34 KiB
Newer Older
from __future__ import annotations
Frederik Hennig's avatar
Frederik Hennig committed
from typing import Callable, TypeVar, Generic, ParamSpec
from types import MethodType

from functools import wraps

from .basic_nodes import SfgCallTreeNode

Frederik Hennig's avatar
Frederik Hennig committed
V = TypeVar("V")
R = TypeVar("R")
P = ParamSpec("P")
Frederik Hennig's avatar
Frederik Hennig committed
class VisitorDispatcher(Generic[V, R]):
    """Callable that dispatches a visitor method on the class of its node argument.

    Handlers are registered per node class with :meth:`case`. On a call, the
    node's method resolution order is scanned and the handler of the first
    class that has one registered is invoked; when no class matches, the
    originally wrapped method is invoked instead. Because ``__get__`` is
    implemented, an instance stored as a class attribute binds to visitor
    instances the same way a plain function does.
    """

    def __init__(self, wrapped_method: Callable[..., R]):
        # Fallback used when no registered case matches the node's class.
        self._wrapped_method: Callable[..., R] = wrapped_method
        # Node class -> handler registered for it through ``case``.
        self._dispatch_dict: dict[type, Callable[..., R]] = {}

    def case(self, node_type: type):
        """Return a decorator registering its target as the handler for ``node_type``.

        The decorated function is stored and returned unchanged, so it stays
        usable under its own name. Applying the decorator raises ``ValueError``
        if a handler for ``node_type`` has already been registered.
        """

        def register(handler: Callable[..., R]):
            if node_type not in self._dispatch_dict:
                self._dispatch_dict[node_type] = handler
                return handler
            raise ValueError(f"Duplicate visitor case {node_type}")

        return register

    def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R:
        """Invoke the handler that best matches the class of ``node``.

        ``instance``, ``node`` and any extra positional/keyword arguments are
        forwarded unchanged to the selected handler, whose result is returned.
        """
        # The MRO starts with the node's own class and continues through its
        # bases, so the most specific registered case is picked first.
        handler = next(
            (
                self._dispatch_dict[klass]
                for klass in node.__class__.mro()
                if klass in self._dispatch_dict
            ),
            self._wrapped_method,
        )
        return handler(instance, node, *args, **kwargs)

    def __get__(self, obj: V, objtype=None) -> Callable[..., R]:
        """Bind the dispatcher to ``obj``; access through the class yields the dispatcher itself."""
        return self if obj is None else MethodType(self, obj)


def visitor(method):
    """Turn ``method`` into a :class:`VisitorDispatcher` that falls back to ``method``.

    Per-node-class handlers are then attached through the returned object's
    ``case`` decorator. ``functools.wraps`` copies the metadata of ``method``
    (name, qualified name, docstring, ...) onto the dispatcher.
    """
    dispatcher = VisitorDispatcher(method)
    return wraps(method)(dispatcher)