Skip to content
Snippets Groups Projects
Commit 1934248a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

std::tuple mapping

parent c5a41e2e
No related merge requests found
Pipeline #58423 passed with stages
in 53 seconds
from .std_mdspan import StdMdspan, mdspan_ref
from .std_vector import StdVector, std_vector_ref
from .std_tuple import StdTuple, std_tuple_ref
__all__ = [
"StdMdspan", "StdVector", "std_vector_ref",
"mdspan_ref"
"StdMdspan",
"mdspan_ref",
"StdVector",
"std_vector_ref",
"StdTuple",
"std_tuple_ref",
]
from typing import Sequence
from pystencils.typing import BasicType, TypedSymbol
from ...tree import SfgStatements
from ..source_objects import SrcVector
from ..source_objects import TypedSymbolOrObject
from ...types import SrcType, cpp_typename
from ...source_components import SfgHeaderInclude
class StdTuple(SrcVector):
def __init__(
self,
identifier: str,
element_types: Sequence[BasicType],
const: bool = False,
ref: bool = False,
):
self._element_types = element_types
self._length = len(element_types)
elt_type_strings = tuple(cpp_typename(t) for t in self._element_types)
src_type = f"{'const' if const else ''} std::tuple< {', '.join(elt_type_strings)} > {'&' if ref else ''}"
super().__init__(identifier, SrcType(src_type))
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("tuple", system_header=True)}
def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
if coordinate < 0 or coordinate >= self._length:
raise ValueError(
f"Index {coordinate} out-of-bounds for std::tuple with {self._length} entries."
)
if destination.dtype != self._element_types[coordinate]:
raise ValueError(
f"Cannot extract type {destination.dtype} from std::tuple entry "
"of type {self._element_types[coordinate]}"
)
return SfgStatements(
f"{destination.dtype} {destination.name} = std::get< {coordinate} >({self.identifier});",
(destination,),
(self,),
)
def std_tuple_ref(
identifier: str, components: Sequence[TypedSymbol], const: bool = True
):
elt_types = tuple(c.dtype for c in components)
return StdTuple(identifier, elt_types, const=const, ref=True)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment