# Newer
# Older
from typing import Callable, TypeVar, Generic, ParamSpec
from types import MethodType
from functools import wraps
from .basic_nodes import SfgCallTreeNode
V = TypeVar("V")
R = TypeVar("R")
P = ParamSpec("P")
class VisitorDispatcher(Generic[V, R]):
def __init__(self, wrapped_method: Callable[..., R]):
self._dispatch_dict: dict[type, Callable[..., R]] = {}
self._wrapped_method: Callable[..., R] = wrapped_method
def case(self, node_type: type):
"""Decorator for visitor's methods"""
if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}")
self._dispatch_dict[node_type] = handler
return handler
return decorate
def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R:
for cls in node.__class__.mro():
if cls in self._dispatch_dict:
return self._dispatch_dict[cls](instance, node, *args, **kwargs)
return self._wrapped_method(instance, node, *args, **kwargs)
def __get__(self, obj: V, objtype=None) -> Callable[..., R]:
if obj is None:
return self
return MethodType(self, obj)
def visitor(method):
    """Turn `method` into a `VisitorDispatcher` with `method` as its fallback.

    The dispatcher is passed through `functools.wraps(method)`, so it carries
    the metadata of the method it wraps.
    """
    copy_metadata = wraps(method)
    dispatcher = VisitorDispatcher(method)
    return copy_metadata(dispatcher)