from typing import Union, cast

import numpy as np

from pystencils import Field
from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol

from ...tree import SfgStatements
from ..source_objects import SrcField
from ...source_components import SfgHeaderInclude
from ...types import PsType, cpp_typename, SrcType
from ...exceptions import SfgException


class StdMdspan(SrcField):
    """Source object for a C++ ``std::mdspan`` parameter.

    The generated type is ``std::mdspan< T, std::extents< I, E0, E1, ... > >``,
    optionally followed by ``&``. The ``extract_*`` methods emit the statements that
    read a pystencils field's data pointer, sizes and strides from the mdspan object.
    """

    # Extent value for dimensions whose size is only known at runtime.
    dynamic_extent = "std::dynamic_extent"

    def __init__(self, identifer: str,
                 T: PsType,
                 extents: tuple[int | str, ...],
                 extents_type: PsType = int,
                 reference: bool = False):
        """
        Args:
            identifer: Name of the mdspan object in the generated code.
            T: Element type of the mdspan.
            extents: One entry per dimension; either a static size or
                `StdMdspan.dynamic_extent` for sizes only known at runtime.
            extents_type: Index type passed as first argument of ``std::extents``.
            reference: If True, ``&`` is appended to the type string.
        """
        #   Materialize once: the extents are both joined into the type string and
        #   stored for later dimension checks, so a one-shot iterable must not be consumed twice.
        extents = tuple(extents)

        cpp_typestr = cpp_typename(T)
        extents_type_str = cpp_typename(extents_type)

        #   Joining the index type together with the extents yields the valid rank-0 form
        #   `std::extents< I >` for empty extents instead of `std::extents< I,  >`.
        extents_args = [extents_type_str] + [str(e) for e in extents]
        extents_str = f"std::extents< {', '.join(extents_args)} >"
        typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}"
        super().__init__(SrcType(typestring), identifer)

        self._extents = extents

    @property
    def required_includes(self) -> set[SfgHeaderInclude]:
        """Headers needed by the generated code: the system header ``<experimental/mdspan>``."""
        return {SfgHeaderInclude("experimental/mdspan", system_header=True)}

    def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
        """Emit a declaration of ``ptr_symbol`` initialized from ``data_handle()``."""
        return SfgStatements(
            f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data_handle();",
            (ptr_symbol, ),
            (self, )
        )

    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
        """Emit code that obtains or checks the extent of dimension ``coordinate``.

        A `FieldShapeSymbol` is declared and initialized from ``extents().extent(coordinate)``.
        An integer size becomes a runtime ``assert`` on that extent.

        Coordinates beyond the mdspan's rank are only accepted for a fixed size of 1.
        For those, a comment statement is emitted instead of code.

        Raises:
            SfgException: If ``coordinate`` exceeds the rank and ``size`` is symbolic or not 1.
        """
        dim = len(self._extents)
        if coordinate >= dim:
            if isinstance(size, FieldShapeSymbol):
                raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!")
            if size != 1:
                raise SfgException(
                    f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!")
            #   trivial trailing index dimensions are OK -> do nothing
            return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ())

        if isinstance(size, FieldShapeSymbol):
            return SfgStatements(
                f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});",
                (size, ),
                (self, )
            )

        return SfgStatements(
            f"assert( {self._identifier}.extents().extent({coordinate}) == {size} );",
            (), (self, )
        )

    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
        """Emit code that obtains or checks the stride of dimension ``coordinate``.

        A `FieldStrideSymbol` is declared and initialized from ``stride(coordinate)``.
        An integer stride becomes a runtime ``assert`` on that stride.

        Raises:
            SfgException: If ``coordinate`` is not below the mdspan's rank.
        """
        dim = len(self._extents)
        if coordinate >= dim:
            raise SfgException(
                f"Cannot extract stride in coordinate {coordinate} from a {dim}-dimensional mdspan")

        if isinstance(stride, FieldStrideSymbol):
            return SfgStatements(
                f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});",
                (stride, ),
                (self, )
            )

        return SfgStatements(
            f"assert( {self._identifier}.stride({coordinate}) == {stride} );",
            (), (self, )
        )


def mdspan_ref(field: Field, extents_type: type = np.uint32):
    """Creates a `std::mdspan &` for a given pystencils field.

    The mdspan gets one extent per spatial dimension, followed by one extent per
    index dimension. A lone index dimension of shape ``(1,)`` is omitted.
    Symbolic (`FieldShapeSymbol`) sizes map to `StdMdspan.dynamic_extent`; fixed sizes
    are used as static extents.

    Args:
        field: The pystencils field to map.
        extents_type: Index type of the generated `std::extents`.

    Returns:
        A `StdMdspan` named after the field, with the field's dtype, declared as a reference.

    Raises:
        NotImplementedError: If the field's layout is not the structure-of-arrays layout.
    """
    from pystencils.field import layout_string_to_tuple

    if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
        raise NotImplementedError("mdspan mapping is currently only available for structure-of-arrays fields")

    def _to_extent(s) -> str | int:
        #   symbolic sizes are only known at runtime -> dynamic extent
        return StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else cast(int, s)

    extents: list[str | int] = [_to_extent(s) for s in field.spatial_shape]

    if field.index_shape != (1,):
        #   One extent per index dimension. The previous `extents += ...` called list.extend,
        #   which split the dynamic-extent string into characters and raised for int sizes.
        extents.extend(_to_extent(s) for s in field.index_shape)

    return StdMdspan(field.name, field.dtype,
                     tuple(extents),
                     extents_type=extents_type,
                     reference=True)