# std_mdspan.py: mapping of pystencils fields to C++ `std::mdspan` objects.
from typing import Union, cast
import numpy as np

from pystencils import Field
from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol

from ...tree import SfgStatements
from ..source_objects import SrcField
from ...source_components import SfgHeaderInclude
from ...types import PsType, cpp_typename, SrcType
from ...exceptions import SfgException
    #: Extent placeholder emitted into ``std::extents`` for dimensions whose size
    #: is not a compile-time constant.
    dynamic_extent = "std::dynamic_extent"

    def __init__(self, identifer: str,
                 T: PsType,
                 extents: tuple[int | str, ...],
                 extents_type: PsType = int,
                 reference: bool = False):
        """Create an mdspan source object named ``identifer``.

        Args:
            identifer: Name of the mdspan object in the generated code.
            T: Element type of the mdspan.
            extents: One entry per mdspan dimension; either an integer extent or a
                string that is emitted verbatim (e.g. ``StdMdspan.dynamic_extent``).
            extents_type: Index type passed as first template argument of ``std::extents``.
            reference: If ``True``, the generated type string ends in ``&``.
        """
        # Materialize once: the extents are joined into the type string below and
        # must still be available afterwards to determine the mdspan's rank.
        extents = tuple(extents)

        cpp_typestr = cpp_typename(T)
        extents_type_str = cpp_typename(extents_type)

        extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
        typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}"
        super().__init__(SrcType(typestring), identifer)

        # Read by extract_size / extract_stride to know how many dimensions exist.
        self._extents = extents

    def required_includes(self) -> set[SfgHeaderInclude]:
        """Return the headers generated code needs in order to use ``std::mdspan``."""
        mdspan_header = SfgHeaderInclude("experimental/mdspan", system_header=True)
        return {mdspan_header}

    def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
        """Emit a declaration binding ``ptr_symbol`` to this mdspan's ``data_handle()``.

        The resulting statements define ``ptr_symbol`` and depend on this mdspan.
        """
        declaration = f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data_handle();"
        defined_symbols = (ptr_symbol, )
        dependencies = (self, )
        return SfgStatements(declaration, defined_symbols, dependencies)

    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
        """Emit code that binds or checks this mdspan's extent in ``coordinate``.

        If ``size`` is a ``FieldShapeSymbol``, a variable of that name is declared and
        initialized from ``extents().extent(coordinate)``. Otherwise an ``assert``
        comparing the runtime extent against ``size`` is emitted.

        A ``coordinate`` at or beyond the mdspan's rank is accepted only for a constant
        size of 1 (a trivial trailing dimension); then only a comment is emitted.

        Raises:
            SfgException: If ``coordinate`` is beyond the rank and ``size`` is a
                symbol or an integer other than 1.
        """
        dim = len(self._extents)
        if coordinate >= dim:
            if isinstance(size, FieldShapeSymbol):
                raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!")
            elif size != 1:
                raise SfgException(
                    f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!")
            else:
                #   trivial trailing index dimensions are OK -> do nothing
                return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ())

        if isinstance(size, FieldShapeSymbol):
            return SfgStatements(
                f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});",
                (size, ),
                (self, )
            )
        else:
            return SfgStatements(
                f"assert( {self._identifier}.extents().extent({coordinate}) == {size} );",
                (), (self, )
            )

    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
        """Emit code that binds or checks this mdspan's stride in ``coordinate``.

        If ``stride`` is a ``FieldStrideSymbol``, a variable of that name is declared
        and initialized from ``stride(coordinate)``. Otherwise an ``assert`` comparing
        the runtime stride against ``stride`` is emitted.

        Raises:
            SfgException: If ``coordinate`` is not below the mdspan's rank.
        """
        dim = len(self._extents)
        if coordinate >= dim:
            raise SfgException(
                f"Cannot extract stride in coordinate {coordinate} from a {dim}-dimensional mdspan")

        if isinstance(stride, FieldStrideSymbol):
            # Mirrors the symbol branch of extract_size: declare and return the binding.
            return SfgStatements(
                f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});",
                (stride, ),
                (self, )
            )
        else:
            return SfgStatements(
                f"assert( {self._identifier}.stride({coordinate}) == {stride} );",
                (), (self, )
            )


def mdspan_ref(field: Field, extents_type: type = np.uint32):
    """Creates a `std::mdspan &` for a given pystencils field.

    Every spatial dimension of ``field`` becomes one mdspan extent. Index
    dimensions are appended after them, unless the index shape is the trivial
    ``(1,)``. Sizes given as ``FieldShapeSymbol`` become
    ``StdMdspan.dynamic_extent``; all other sizes are used as integers.

    Args:
        field: The pystencils field to map.
        extents_type: Index type used for the mdspan's ``std::extents``.

    Returns:
        A ``StdMdspan`` named after ``field``, typed as a reference.

    Raises:
        NotImplementedError: If ``field`` does not have structure-of-arrays layout.
    """
    from pystencils.field import layout_string_to_tuple

    if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
        raise NotImplementedError("mdspan mapping is currently only available for structure-of-arrays fields")

    def to_extent(s) -> str | int:
        # Symbolic sizes are only known at runtime -> dynamic extent.
        return StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else cast(int, s)

    extents: list[str | int] = [to_extent(s) for s in field.spatial_shape]

    # A trivial (1,) index shape contributes no mdspan dimension. Use extend
    # (one entry per index dimension): `list += str` would add single characters.
    if field.index_shape != (1,):
        extents.extend(to_extent(s) for s in field.index_shape)

    return StdMdspan(field.name, field.dtype,
                     tuple(extents),
                     extents_type=extents_type,
                     reference=True)