from __future__ import annotations
from typing import Callable, TypeVar, Generic, ParamSpec
from types import MethodType

from functools import wraps

from .basic_nodes import SfgCallTreeNode

# Type of the visitor object a dispatcher is invoked on or bound to.
V = TypeVar("V")
# Return type shared by the fallback method and the registered case handlers.
R = TypeVar("R")
# Parameter specification; not referenced in the visible part of this module.
P = ParamSpec("P")


class VisitorDispatcher(Generic[V, R]):
    """Dispatch a visitor method call on the type of the visited node.

    The dispatcher wraps a fallback method and keeps a registry that maps
    node types to handler functions. When called, it walks the method
    resolution order of the node's class and invokes the handler registered
    for the first class it finds in the registry; if no class matches, the
    wrapped fallback method is invoked instead.

    The dispatcher implements ``__get__``, so when it is stored as a class
    attribute, accessing it through an instance yields a bound method.
    """

    def __init__(self, wrapped_method: Callable[..., R]):
        """Create a dispatcher without any registered cases.

        Args:
            wrapped_method: Fallback invoked when no registered case matches
                the node passed to `__call__`.
        """
        # Handlers keyed by the node type they were registered for.
        self._dispatch_dict: dict[type, Callable[..., R]] = {}
        self._wrapped_method: Callable[..., R] = wrapped_method

    def case(self, node_type: type) -> Callable[[Callable[..., R]], Callable[..., R]]:
        """Return a decorator registering a handler for ``node_type``.

        The decorator stores the decorated function in the registry and
        returns it unchanged.

        Args:
            node_type: Node class the decorated handler is responsible for.

        Raises:
            ValueError: Raised by the returned decorator if a handler for
                ``node_type`` has already been registered.
        """

        def decorate(handler: Callable[..., R]) -> Callable[..., R]:
            if node_type in self._dispatch_dict:
                raise ValueError(f"Duplicate visitor case {node_type}")
            self._dispatch_dict[node_type] = handler
            return handler

        return decorate

    def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R:
        """Invoke the handler that best matches the class of ``node``.

        Args:
            instance: Visitor object, passed as first argument to the handler.
            node: Node whose class selects the handler.
            *args: Extra positional arguments forwarded to the handler.
            **kwargs: Extra keyword arguments forwarded to the handler.

        Returns:
            Whatever the selected handler, or the fallback method, returns.
        """
        # ``__mro__`` is the linearization already stored on the class;
        # calling ``mro()`` would recompute it on every single dispatch.
        for cls in node.__class__.__mro__:
            if cls in self._dispatch_dict:
                return self._dispatch_dict[cls](instance, node, *args, **kwargs)

        return self._wrapped_method(instance, node, *args, **kwargs)

    def __get__(self, obj: V | None, objtype: type | None = None) -> Callable[..., R]:
        """Bind the dispatcher to ``obj`` like a regular method.

        Returns the dispatcher itself when accessed on the owning class
        (``obj`` is ``None``); otherwise returns a bound method that supplies
        ``obj`` as the ``instance`` argument of `__call__`.
        """
        if obj is None:
            return self
        return MethodType(self, obj)


def visitor(method):
    """Turn ``method`` into a type-dispatching visitor method.

    ``method`` becomes the fallback of a new `VisitorDispatcher`. The
    dispatcher is passed through ``functools.wraps(method)`` so that it
    carries the metadata of ``method``, and is returned in its place.
    """
    copy_metadata = wraps(method)
    dispatcher = VisitorDispatcher(method)
    return copy_metadata(dispatcher)