diff --git a/pystencilssfg/context.py b/pystencilssfg/context.py index 659eecca6382745b5f84ef2927353faf2974cf76..c4eb5952a3dc3e30852ecdaf88a17d9d1a707c9c 100644 --- a/pystencilssfg/context.py +++ b/pystencilssfg/context.py @@ -13,11 +13,11 @@ from pystencils import Field from pystencils.astnodes import KernelFunction from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle -from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode +from .tree import SfgCallTreeNode, SfgSequence, SfgKernelCallNode, SfgStatements from .tree.deferred_nodes import SfgDeferredFieldMapping from .tree.builders import SfgBranchBuilder, make_sequence from .tree.visitors import CollectIncludes -from .source_concepts.containers import SrcField +from .source_concepts import SrcField, TypedSymbolOrObject from .source_components import SfgFunction, SfgHeaderInclude @@ -167,7 +167,7 @@ class SfgContext: #---------------------------------------------------------------------------------------------- - # Call Tree Node Factory + # In-Sequence builders to be used within the second phase of SfgContext.function(). #---------------------------------------------------------------------------------------------- def call(self, kernel_handle: SfgKernelHandle) -> SfgKernelCallNode: @@ -182,4 +182,7 @@ class SfgContext: raise NotImplementedError("Automatic field extraction is not implemented yet.") else: return SfgDeferredFieldMapping(field, src_object) + + def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): + return SfgStatements(mapping, (lhs,), (rhs,)) \ No newline at end of file diff --git a/pystencilssfg/kernel_namespace.py b/pystencilssfg/kernel_namespace.py index 191d6d3d5de52770df42bad70485d07aa78ad958..890ff60119eb8ac413636b0ca8b122a3ea999ff4 100644 --- a/pystencilssfg/kernel_namespace.py +++ b/pystencilssfg/kernel_namespace.py @@ -1,5 +1,3 @@ -# from .context import SfgContext - from typing import Sequence from pystencils import CreateKernelConfig, create_kernel diff --git a/pystencilssfg/source_components/header_include.py b/pystencilssfg/source_components/header_include.py index 915423d7ddd9a1fb5cac071b6a9450298641eed9..33454563fbd9726d27b86cd0642515ef757e764f 100644 --- a/pystencilssfg/source_components/header_include.py +++ b/pystencilssfg/source_components/header_include.py @@ -24,6 +24,7 @@ class SfgHeaderInclude: return hash((self._header_file, self._system_header, self._private)) def __eq__(self, other: SfgHeaderInclude) -> bool: - return (self._header_file == other._header_file + return (isinstance(other, SfgHeaderInclude) + and self._header_file == other._header_file and self._system_header == other._system_header and self._private == other._private) diff --git a/pystencilssfg/source_concepts/__init__.py b/pystencilssfg/source_concepts/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..10ba6778432aa2169633f8652ba0e646bfc0d72d 100644 --- a/pystencilssfg/source_concepts/__init__.py +++ b/pystencilssfg/source_concepts/__init__.py @@ -0,0 +1,6 @@ +from .source_objects import SrcObject, SrcField, SrcVector, PsType, SrcType, TypedSymbolOrObject + +__all__ = [ + SrcObject, SrcField, SrcVector, + PsType, SrcType, TypedSymbolOrObject +] \ No newline at end of file diff --git a/pystencilssfg/source_concepts/containers.py b/pystencilssfg/source_concepts/containers.py deleted file mode 100644 index c634310c80f43fee81fc3c14a7469c62c632b3f6..0000000000000000000000000000000000000000 --- a/pystencilssfg/source_concepts/containers.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional, Union -from abc import ABC, abstractmethod - -from pystencils import Field -from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol - -from .source_concepts import SrcObject -from ..tree import SfgStatements, SfgSequence - -class SrcField(SrcObject): - def __init__(self, src_type, identifier: Optional[str]): - super().__init__(src_type, identifier) - - @abstractmethod - def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements: - pass - - @abstractmethod - def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: - pass - - @abstractmethod - def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: - pass - - def extract_parameters(self, field: Field) -> SfgSequence: - ptr = FieldPointerSymbol(field.name, field.dtype, False) - - from ..tree import make_sequence - - return make_sequence( - self.extract_ptr(ptr), - *(self.extract_size(c, s) for c, s in enumerate(field.shape)), - *(self.extract_stride(c, s) for c, s in enumerate(field.strides)) - ) - diff --git a/pystencilssfg/source_concepts/cpp/__init__.py b/pystencilssfg/source_concepts/cpp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3397a175cbaf3a68b1ad9a60a8dba1c789212a --- /dev/null +++ b/pystencilssfg/source_concepts/cpp/__init__.py @@ -0,0 +1,6 @@ +from .std_mdspan import std_mdspan +from .std_vector import std_vector, std_vector_ref + +__all__= [ + std_mdspan, std_vector, std_vector_ref +] \ No newline at end of file diff --git a/pystencilssfg/source_concepts/cpp/std_mdspan.py b/pystencilssfg/source_concepts/cpp/std_mdspan.py index c4ff910722247afdc2424cb07c612fa58ac55810..2b99b86e9521b415a17234536f9ea75962bea296 100644 --- a/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -1,22 +1,22 @@ from typing import Set, Union, Tuple -from numpy import dtype from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol -from pystencilssfg.source_components import SfgHeaderInclude - from ...tree import SfgStatements -from ..containers import SrcField +from ..source_objects import SrcField from ...source_components.header_include import SfgHeaderInclude +from ..source_objects import PsType from ...exceptions import SfgException class std_mdspan(SrcField): dynamic_extent = "std::dynamic_extent" - def __init__(self, identifer: str, T: dtype, extents: Tuple[int, str]): + def __init__(self, identifer: str, T: PsType, extents: Tuple[int, str], extents_type: PsType = int, reference: bool = False): from pystencils.typing import create_type T = create_type(T) - typestring = f"std::mdspan< {T}, std::extents< int, {', '.join(str(e) for e in extents)} > >" + extents_type = create_type(extents_type) + + typestring = f"std::mdspan< {T}, std::extents< {extents_type}, {', '.join(str(e) for e in extents)} > > {'&' if reference else ''}" super().__init__(typestring, identifer) self._extents = extents @@ -33,8 +33,15 @@ class std_mdspan(SrcField): ) def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: - if coordinate >= len(self._extents): - raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") + dim = len(self._extents) + if coordinate >= dim: + if isinstance(size, FieldShapeSymbol): + raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!") + elif size != 1: + raise SfgException(f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!") + else: + # trivial trailing index dimensions are OK -> do nothing + return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ()) if isinstance(size, FieldShapeSymbol): return SfgStatements( @@ -50,7 +57,7 @@ class std_mdspan(SrcField): def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: if coordinate >= len(self._extents): - raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") + raise SfgException(f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") if isinstance(stride, FieldStrideSymbol): return SfgStatements( diff --git a/pystencilssfg/source_concepts/cpp/std_vector.py b/pystencilssfg/source_concepts/cpp/std_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2e8f03a6718eb650845c74779f5981727a739a --- /dev/null +++ b/pystencilssfg/source_concepts/cpp/std_vector.py @@ -0,0 +1,73 @@ +from typing import Set, Union, Tuple + +from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol, create_type + +from ...tree import SfgStatements +from ..source_objects import SrcField, SrcVector +from ..source_objects import SrcObject, SrcType, TypedSymbolOrObject +from ...source_components.header_include import SfgHeaderInclude +from ...exceptions import SfgException + +class std_vector(SrcVector, SrcField): + def __init__(self, identifer: str, T: SrcType, unsafe: bool = False): + typestring = f"std::vector< {T} >" + super(SrcObject, self).__init__(identifer, typestring) + + self._element_type = T + self._unsafe = unsafe + + @property + def required_includes(self) -> Set[SfgHeaderInclude]: + return { SfgHeaderInclude("vector", system_header=True) } + + def extract_ptr(self, ptr_symbol: FieldPointerSymbol): + if ptr_symbol.dtype != self._element_type: + if self._unsafe: + mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = ({ptr_symbol.dtype}) {self._identifier}.data();" + else: + raise SfgException("Field type and std::vector element type do not match, and unsafe extraction was not enabled.") + else: + mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data();" + + return SfgStatements(mapping, (ptr_symbol,), (self,)) + + def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: + if coordinate > 0: + raise SfgException(f"Cannot extract size in coordinate {coordinate} from std::vector") + + if isinstance(size, FieldShapeSymbol): + return SfgStatements( + f"{size.dtype} {size.name} = {self._identifier}.size();", + (size, ), + (self, ) + ) + else: + return SfgStatements( + f"assert( {self._identifier}.size() == {size} );", + (), (self, ) + ) + + def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: + if coordinate > 0: + raise SfgException(f"Cannot extract stride in coordinate {coordinate} from std::vector") + + if isinstance(stride, FieldStrideSymbol): + return SfgStatements(f"{stride.dtype} {stride.name} = 1;", (stride, ), ()) + else: + return SfgStatements(f"assert( 1 == {stride} );", (), ()) + + + def extract_component(self, destination: TypedSymbolOrObject, coordinate: int): + if self._unsafe: + mapping = f"{destination.dtype} {destination.name} = {self._identifier}[{coordinate}];" + else: + mapping = f"{destination.dtype} {destination.name} = {self._identifier}.at({coordinate});" + + return SfgStatements(mapping, (destination,), (self,)) + + + +class std_vector_ref(std_vector): + def __init__(self, identifer: str, T: SrcType): + typestring = f"std::vector< {T} > &" + super(SrcObject, self).__init__(identifer, typestring) diff --git a/pystencilssfg/source_concepts/source_concepts.py b/pystencilssfg/source_concepts/source_concepts.py deleted file mode 100644 index 66a22d86fcbe9caf89b581037a1b530362f26dbe..0000000000000000000000000000000000000000 --- a/pystencilssfg/source_concepts/source_concepts.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, Sequence, Union, Set - -if TYPE_CHECKING: - from ..source_components import SfgHeaderInclude - -from abc import ABC -from pystencils import TypedSymbol - -class SrcObject(ABC): - def __init__(self, src_type, identifier: Optional[str]): - self._src_type = src_type - self._identifier = identifier - - @property - def src_type(self): - return self._src_type - - @property - def identifier(self): - return self._identifier - - @property - def required_includes(self) -> Set[SfgHeaderInclude]: - return set() - - @property - def typed_symbol(self): - return TypedSymbol(self._identifier, self._src_type) diff --git a/pystencilssfg/source_concepts/source_objects.py b/pystencilssfg/source_concepts/source_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..244ff75242c44ea603b7ae3c5ecad460f4ca0307 --- /dev/null +++ b/pystencilssfg/source_concepts/source_objects.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union, Set, TypeAlias, NewType + +if TYPE_CHECKING: + from ..source_components import SfgHeaderInclude + from ..tree import SfgStatements, SfgSequence + +from numpy import dtype + +from abc import ABC, abstractmethod + +from pystencils import TypedSymbol, Field +from pystencils.typing import AbstractType, FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol + +PsType: TypeAlias = Union[type, dtype, AbstractType] +"""Types used in interacting with pystencils. + +PsType represents various ways of specifying types within pystencils. +In particular, it encompasses most ways to construct an instance of `AbstractType`, +for example via `create_type`. + +(Note that, while `create_type` does accept strings, they are excluded here for +reasons of safety. It is discouraged to use strings for type specifications when working +with pystencils!) +""" + +SrcType = NewType('SrcType', str) +"""Nonprimitive C/C++-Types occuring during source file generation. + +Nonprimitive C/C++ types are represented by their names. +When necessary, the SFG package checks equality of types by these name strings; it does +not care about typedefs, aliases, namespaces, etc! +""" + + +class SrcObject: + """C/C++ object of nonprimitive type. + + Two objects are identical if they have the same identifier and type string.""" + + def __init__(self, src_type: SrcType, identifier: Optional[str]): + self._src_type = src_type + self._identifier = identifier + + @property + def identifier(self): + return self._identifier + + @property + def name(self): + """For interface compatibility with ps.TypedSymbol""" + return self._identifier + + @property + def dtype(self): + return self._src_type + + @property + def required_includes(self) -> Set[SfgHeaderInclude]: + return set() + + def __hash__(self) -> int: + return hash((self._identifier, self._src_type)) + + def __eq__(self, other: SrcObject) -> bool: + return (isinstance(other, SrcObject) + and self._identifier == other._identifier + and self._src_type == other._src_type) + + +TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject] + + +class SrcField(SrcObject, ABC): + def __init__(self, src_type: SrcType, identifier: Optional[str]): + super().__init__(src_type, identifier) + + @abstractmethod + def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements: + pass + + @abstractmethod + def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: + pass + + @abstractmethod + def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: + pass + + def extract_parameters(self, field: Field) -> SfgSequence: + ptr = FieldPointerSymbol(field.name, field.dtype, False) + + from ..tree import make_sequence + + return make_sequence( + self.extract_ptr(ptr), + *(self.extract_size(c, s) for c, s in enumerate(field.shape)), + *(self.extract_stride(c, s) for c, s in enumerate(field.strides)) + ) + + +class SrcVector(SrcObject): + @abstractmethod + def extract_component(self, destination: TypedSymbolOrObject, coordinate: int): + pass diff --git a/pystencilssfg/tree/basic_nodes.py b/pystencilssfg/tree/basic_nodes.py index 429f504314c2a077d2f48b61698abf33f9cb8e1f..75fceb5ddf6f6e9edd587643287fc88452dc75e7 100644 --- a/pystencilssfg/tree/basic_nodes.py +++ b/pystencilssfg/tree/basic_nodes.py @@ -6,13 +6,10 @@ if TYPE_CHECKING: from ..source_components import SfgHeaderInclude from abc import ABC, abstractmethod -from functools import reduce from itertools import chain -from jinja2.filters import do_indent - from ..kernel_namespace import SfgKernelHandle -from ..source_concepts.source_concepts import SrcObject +from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject from ..exceptions import SfgException @@ -54,7 +51,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC): @property @abstractmethod - def required_symbols(self) -> Set[TypedSymbol]: + def required_parameters(self) -> Set[TypedSymbolOrObject]: pass @@ -77,33 +74,25 @@ class SfgStatements(SfgCallTreeLeaf): def __init__(self, code_string: str, - defined_objects: Sequence[Union[SrcObject, TypedSymbol]], - required_objects: Sequence[Union[SrcObject, TypedSymbol]]): + defined_params: Sequence[TypedSymbolOrObject], + required_params: Sequence[TypedSymbolOrObject]): self._code_string = code_string - def to_symbol(obj: Union[SrcObject, TypedSymbol]): - if isinstance(obj, SrcObject): - return obj.typed_symbol - elif isinstance(obj, TypedSymbol): - return obj - else: - raise ValueError(f"Required object in expression is neither TypedSymbol nor SrcObject: {obj}") - - self._defined_symbols = set(map(to_symbol, defined_objects)) - self._required_symbols = set(map(to_symbol, required_objects)) + self._defined_params = set(defined_params) + self._required_params = set(required_params) self._required_includes = set() - for obj in chain(required_objects, defined_objects): + for obj in chain(required_params, defined_params): if isinstance(obj, SrcObject): self._required_includes |= obj.required_includes @property - def required_symbols(self) -> Set[TypedSymbol]: - return self._required_symbols + def required_parameters(self) -> Set[TypedSymbolOrObject]: + return self._required_params @property - def defined_symbols(self) -> Set[TypedSymbol]: - return self._defined_symbols + def defined_parameters(self) -> Set[TypedSymbolOrObject]: + return self._defined_params @property def required_includes(self) -> Set[SfgHeaderInclude]: @@ -153,7 +142,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf): self._kernel_handle = kernel_handle @property - def required_symbols(self) -> Set[TypedSymbol]: + def required_parameters(self) -> Set[TypedSymbolOrObject]: return set(p.symbol for p in self._kernel_handle.parameters) def get_code(self, ctx: SfgContext) -> str: diff --git a/pystencilssfg/tree/conditional.py b/pystencilssfg/tree/conditional.py index acfcc615d7cd13d8d132120a2677bb1f01da22a0..87313442a4f5c06e975b97d02b82ac417b044291 100644 --- a/pystencilssfg/tree/conditional.py +++ b/pystencilssfg/tree/conditional.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Optional +from typing import TYPE_CHECKING, Sequence, Optional, Set if TYPE_CHECKING: from ..context import SfgContext @@ -8,6 +8,7 @@ from jinja2.filters import do_indent from pystencils.typing import TypedSymbol from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf +from ..source_concepts.source_objects import TypedSymbolOrObject class SfgCondition(SfgCallTreeLeaf): pass @@ -16,7 +17,7 @@ class SfgCustomCondition(SfgCondition): def __init__(self, cond_text: str): self._cond_text = cond_text - def required_symbols(self) -> set(TypedSymbol): + def required_parameters(self) -> Set[TypedSymbolOrObject]: return set() def get_code(self, ctx: SfgContext) -> str: diff --git a/pystencilssfg/tree/deferred_nodes.py b/pystencilssfg/tree/deferred_nodes.py index 349c7bd33a2c1adfa2d7567ddee7d739c1845952..b14ba2ddb9d122a9535ffd4f1b4b37587362bc1c 100644 --- a/pystencilssfg/tree/deferred_nodes.py +++ b/pystencilssfg/tree/deferred_nodes.py @@ -14,7 +14,7 @@ from ..exceptions import SfgException from .basic_nodes import SfgCallTreeNode from .builders import make_sequence -from ..source_concepts.containers import SrcField +from ..source_concepts import SrcField class SfgDeferredNode(SfgCallTreeNode, ABC): diff --git a/pystencilssfg/tree/visitors.py b/pystencilssfg/tree/visitors.py index 88c078d274453b350e9add04087313d205082b48..9da7db247b8206d68ea81ce47ca5e79cc6889293 100644 --- a/pystencilssfg/tree/visitors.py +++ b/pystencilssfg/tree/visitors.py @@ -66,7 +66,7 @@ class ExpandingParameterCollector(): return self._visit_branchingNode(node) def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]: - return leaf.required_symbols + return leaf.required_parameters def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: """ @@ -87,7 +87,7 @@ class ExpandingParameterCollector(): iter_nested_sequences(c, visible_params) else: if isinstance(c, SfgStatements): - visible_params -= c.defined_symbols + visible_params -= c.defined_parameters visible_params |= self.visit(c) @@ -117,7 +117,7 @@ class ParameterCollector(): return self._visit_branchingNode(node) def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]: - return leaf.required_symbols + return leaf.required_parameters def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: """ @@ -127,7 +127,7 @@ class ParameterCollector(): params = set() for c in sequence.children[::-1]: if isinstance(c, SfgStatements): - params -= c.defined_symbols + params -= c.defined_parameters assert not isinstance(c, SfgSequence), "Sequence not flattened." params |= self.visit(c) diff --git a/pystencilssfg/types.py b/pystencilssfg/types.py new file mode 100644 index 0000000000000000000000000000000000000000..9e309e823cea2404c5e1f65980fc7bdb163b2b07 --- /dev/null +++ b/pystencilssfg/types.py @@ -0,0 +1,12 @@ +from pystencils.typing import AbstractType, BasicType, StructType, PointerType + + +class SrcType: + """Valid C/C++-Type occuring during source file generation. + + Nonprimitive C/C++ types are represented by their names. + When necessary, the SFG package checks equality of types by these name strings; it does + not care about typedefs, aliases, namespaces, etc! + """ + + diff --git a/tests/mdspan/Makefile b/tests/mdspan/Makefile index 76e57a6bea12495feb0cddc905250ae5536c77d9..1761528f0f2edc9ea2cf07e4ba9b1334a4e87afb 100644 --- a/tests/mdspan/Makefile +++ b/tests/mdspan/Makefile @@ -27,8 +27,10 @@ $(OBJ)/kernels.o: $(GEN_SRC)/kernels.cpp $(GEN_SRC)/kernels.h $(CXX) $(CXX_FLAGS) -c -o $@ $< $(OBJ)/main.o: main.cpp $(GEN_SRC)/kernels.h + @$(dir_guard) $(CXX) $(CXX_FLAGS) -c -o $@ $< $(BIN)/mdspan_test: $(OBJ)/kernels.o $(OBJ)/main.o + @$(dir_guard) $(CXX) $(CXX_FLAGS) -o $@ $^ diff --git a/tests/mdspan/kernels.py b/tests/mdspan/kernels.py index d8f582c5786bfa38a63aaba60d6a4352c656f76a..27a482ae02d852c8de439a07ac41fc61efc37744 100644 --- a/tests/mdspan/kernels.py +++ b/tests/mdspan/kernels.py @@ -1,22 +1,32 @@ +import sympy as sp import numpy as np + from pystencils.session import * from pystencilssfg import SourceFileGenerator -from pystencilssfg.source_concepts.cpp.std_mdspan import std_mdspan +from pystencilssfg.source_concepts.cpp import std_mdspan + +def field_t(field: ps.Field): + return std_mdspan(field.name, + field.dtype, + (std_mdspan.dynamic_extent, std_mdspan.dynamic_extent), + extents_type=np.uint32, + reference=True) + -with SourceFileGenerator() as sfg: +with SourceFileGenerator("poisson") as sfg: src, dst = ps.fields("src, dst(1) : double[2D]") - @ps.kernel - def poisson_gs(): - dst[0,0] @= src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1] - 4 * src[0, 0] + h = sp.Symbol('h') - sfg.include("<iostream>") + @ps.kernel + def poisson_jacobi(): + dst[0,0] @= (src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1]) / 4 - poisson_kernel = sfg.kernels.create(poisson_gs) + poisson_kernel = sfg.kernels.create(poisson_jacobi) - sfg.function("myFunction")( - sfg.map_field(src, std_mdspan(src.name, np.float64, (std_mdspan.dynamic_extent, std_mdspan.dynamic_extent, 1))), - sfg.map_field(dst, std_mdspan(dst.name, np.float64, (2, 2, 1))), + sfg.function("jacobi_smooth")( + sfg.map_field(src, field_t(src)), + sfg.map_field(dst, field_t(dst)), sfg.call(poisson_kernel) ) diff --git a/tests/mdspan/main.cpp b/tests/mdspan/main.cpp index f44ca21ed9c35e25c3e9c5b48d0f7f2ed793abe7..d8247a580ca17ca49f2a2fb1e6535b8bac6fa96f 100644 --- a/tests/mdspan/main.cpp +++ b/tests/mdspan/main.cpp @@ -1,6 +1,58 @@ +#include <iostream> +#include <fstream> + +#include <cstdint> +#include <vector> + +#include <experimental/mdspan> + #include "generated_src/kernels.h" +using field_t = std::mdspan< double, std::extents< uint32_t, std::dynamic_extent, std::dynamic_extent > >; + +double boundary(double x, double y){ + return 1.0; +} + int main(int argc, char ** argv){ - pystencils::myFunction(); + uint32_t N = 8; /* number of grid nodes */ + double h = 1.0 / (double(N) - 1); + uint32_t n_iters = 100; + + std::vector< double > data_src(N*N); + field_t src(data_src.data(), N, N); + + std::vector< double > data_dst(N*N); + field_t dst(data_dst.data(), N, N); + + for(uint32_t i = 0; i < N; ++i){ + for(uint32_t j = 0; j < N; ++j){ + if(i == 0 || j == 0 || i == N-1 || j == N-1){ + src[i, j] = boundary(double(i) * h, double(j) * h); + dst[i, j] = boundary(double(i) * h, double(j) * h); + } + } + } + + for(uint32_t i = 0; i < n_iters; ++i){ + poisson::jacobi_smooth(dst, src); + std::swap(src, dst); + } + + std::ofstream f("data.out", std::ios::trunc | std::ios::out); + + if(!f.is_open()){ + std::cerr << "Could not open output file.\n"; + } else { + for(uint32_t i = 0; i < N; ++i){ + for(uint32_t j = 0; j < N; ++j){ + f << src[i, j] << " "; + } + f << '\n'; + } + } + + f.close(); + return 0; }