from typing import Generator, Union, Optional, Sequence

import sys
import os
from os import path


from pystencils import Field
from pystencils.astnodes import KernelFunction

from .configuration import SfgConfiguration, config_from_commandline, merge_configurations, SfgCodeStyle
from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgStatements
from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.builders import SfgBranchBuilder, make_sequence
from .tree.visitors import CollectIncludes
from .source_concepts import SrcField, TypedSymbolOrObject
from .source_components import SfgFunction, SfgHeaderInclude


class SourceFileGenerator:
    """Context manager that drives source file generation for the running script.

    On entry, previously generated output files are deleted and an `SfgContext`
    is returned for the script to populate. On a clean exit (no exception raised
    inside the ``with`` body), the emitter writes the output files from that
    context.

    Args:
        sfg_config: Optional script-level configuration. It is combined with the
            project and command-line configurations returned by
            `config_from_commandline` via `merge_configurations`.

    Raises:
        TypeError: If `sfg_config` is given but is not an `SfgConfiguration`.
    """

    def __init__(self, sfg_config: SfgConfiguration | None = None):
        # Compare against None explicitly: a falsy object of the wrong type must
        # still be rejected rather than being passed on to merge_configurations.
        if sfg_config is not None and not isinstance(sfg_config, SfgConfiguration):
            raise TypeError("sfg_config is not an SfgConfiguration.")

        # Output files are named after the generator script (extension stripped).
        # NOTE: __main__ has no __file__ in an interactive session, so this line
        # raises AttributeError there.
        import __main__
        scriptpath = __main__.__file__
        scriptname = path.split(scriptpath)[1]
        basename = path.splitext(scriptname)[0]

        project_config, cmdline_config, script_args = config_from_commandline(sys.argv)

        config = merge_configurations(project_config, cmdline_config, sfg_config)

        self._context = SfgContext(script_args, config)

        # Imported at construction time rather than module level; presumably to
        # avoid an import cycle with the emitter package — TODO confirm.
        from .emitters.cpu.basic_cpu import BasicCpuEmitter
        self._emitter = BasicCpuEmitter(basename, config)

    def clean_files(self):
        """Delete every emitter output file that currently exists."""
        for file in self._emitter.output_files:
            # EAFP instead of exists() + remove(): no race between the check and
            # the deletion, and a file that is already absent is not an error.
            try:
                os.remove(file)
            except FileNotFoundError:
                pass

    def __enter__(self):
        """Remove stale output files and hand the generation context to the script."""
        self.clean_files()
        return self._context

    def __exit__(self, exc_type, exc_value, traceback):
        """Write the output files, but only if the ``with`` body raised nothing.

        Returns None, so an exception raised in the ``with`` body propagates.
        """
        if exc_type is None:
            self._emitter.write_files(self._context)


class SfgContext:
    def __init__(self, argv, config: SfgConfiguration):
        self._argv = argv
        self._config = config
        self._default_kernel_namespace = SfgKernelNamespace(self, "kernels")

        self._code_namespace = None

        #   Source Components
        self._includes: set[SfgHeaderInclude] = set()
        self._kernel_namespaces = {self._default_kernel_namespace.name: self._default_kernel_namespace}
        self._functions: dict[str, SfgFunction] = dict()

    @property
    def argv(self) -> Sequence[str]:
        return self._argv

    @property
    def root_namespace(self) -> str | None:
        return self._config.base_namespace

    @property
    def inner_namespace(self) -> str | None:
        return self._code_namespace

    @property
    def fully_qualified_namespace(self) -> str | None:
        match (self.root_namespace, self.inner_namespace):
            case None, None: return None
            case outer, None: return outer
            case None, inner: return inner
            case outer, inner: return f"{outer}::{inner}"
            case _: assert False

    @property
    def codestyle(self) -> SfgCodeStyle:
        assert self._config.codestyle is not None
        return self._config.codestyle

    # ----------------------------------------------------------------------------------------------
    #   Source Component Getters
    # ----------------------------------------------------------------------------------------------

    def includes(self) -> Generator[SfgHeaderInclude, None, None]:
        yield from self._includes

    def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]:
        yield from self._kernel_namespaces.values()

    def functions(self) -> Generator[SfgFunction, None, None]:
        yield from self._functions.values()

    # ----------------------------------------------------------------------------------------------
    #   Source Component Adders
    # ----------------------------------------------------------------------------------------------

    def add_include(self, include: SfgHeaderInclude):
        self._includes.add(include)

    def add_function(self, func: SfgFunction):
        if func.name in self._functions:
            raise ValueError(f"Duplicate function: {func.name}")

        self._functions[func.name] = func
        for incl in CollectIncludes().visit(func._tree):
            self.add_include(incl)

    # ----------------------------------------------------------------------------------------------
    #   Factory-like Adders
    # ----------------------------------------------------------------------------------------------

    @property
    def kernels(self) -> SfgKernelNamespace:
        return self._default_kernel_namespace

    def kernel_namespace(self, name: str) -> SfgKernelNamespace:
        if name in self._kernel_namespaces:
            raise ValueError(f"Duplicate kernel namespace: {name}")

        kns = SfgKernelNamespace(self, name)
        self._kernel_namespaces[name] = kns
        return kns

    def include(self, header_file: str):
        system_header = False
        if header_file.startswith("<") and header_file.endswith(">"):
            header_file = header_file[1:-1]
            system_header = True

        self.add_include(SfgHeaderInclude(header_file, system_header=system_header))

    def function(self,
                 name: str,
                 ast_or_kernel_handle: Optional[Union[KernelFunction, SfgKernelHandle]] = None):
        if name in self._functions:
            raise ValueError(f"Duplicate function: {name}")

        if ast_or_kernel_handle is not None:
            if isinstance(ast_or_kernel_handle, KernelFunction):
                khandle = self._default_kernel_namespace.add(ast_or_kernel_handle)
                tree = SfgKernelCallNode(khandle)
            elif isinstance(ast_or_kernel_handle, SfgKernelCallNode):
                tree = ast_or_kernel_handle
            else:
                raise TypeError("Invalid type of argument `ast_or_kernel_handle`!")

            func = SfgFunction(self, name, tree)
            self.add_function(func)
        else:
            def sequencer(*args: SfgCallTreeNode):
                tree = make_sequence(*args)
                func = SfgFunction(self, name, tree)
                self.add_function(func)

            return sequencer

    # ----------------------------------------------------------------------------------------------
    #   In-Sequence builders to be used within the second phase of SfgContext.function().
    # ----------------------------------------------------------------------------------------------

    def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode:
        return SfgKernelCallNode(kernel_handle)

    @property
    def branch(self) -> SfgBranchBuilder:
        return SfgBranchBuilder()

    def map_field(self, field: Field, src_object: Optional[SrcField] = None) -> SfgDeferredFieldMapping:
        if src_object is None:
            raise NotImplementedError("Automatic field extraction is not implemented yet.")
        else:
            return SfgDeferredFieldMapping(field, src_object)

    def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
        return SfgStatements(mapping, (lhs,), (rhs,))