diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 3cb0c1c5e50aa2b9557a176f3c541283641ad530..bb99101c2a32e7ab274a6e532ca0799e232807e3 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -6,6 +6,8 @@ import re from pystencils.types import PsType, PsCustomType from pystencils.enums import Target +from pystencilssfg.composer.basic_composer import SequencerArg + from ..exceptions import SfgException from ..context import SfgContext from ..composer import ( @@ -13,6 +15,7 @@ from ..composer import ( SfgClassComposer, SfgComposer, SfgComposerMixIn, + make_sequence, ) from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude from ..ir import ( @@ -56,12 +59,13 @@ class SyclHandler(AugExpr): self._ctx = ctx - def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle): + def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle, *, extras: Sequence[SequencerArg]=[]): """Generate a ``parallel_for`` kernel invocation using this command group handler. Args: range: Object, or tuple of integers, indicating the kernel's iteration range kernel: Handle to the pystencils-kernel to be executed + extras: Statements that should be in the parallel_for but before the kernel call """ self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True)) @@ -81,7 +85,7 @@ class SyclHandler(AugExpr): id_param = list(filter(filter_id, kernel.scalar_parameters))[0] - tree = SfgKernelCallNode(kernel) + tree = make_sequence(*extras, SfgKernelCallNode(kernel)) kernel_lambda = SfgLambda(("=",), (id_param,), tree, None) return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda) diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..c6cd011d9fee77119cfca2deace55ebbfc5898e2 --- /dev/null +++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py @@ -0,0 +1,77 @@ +from ...lang import SrcField, IFieldExtraction +from ...ir.source_components import SfgHeaderInclude + +from pystencils import Field +from pystencils.types import ( + PsType, + PsCustomType, +) + +from pystencilssfg.lang.expressions import AugExpr + + +class SyclAccessor(SrcField): + def __init__( + self, + T: PsType, + dimensions: int, + reference: bool = False, + ): + cpp_typestr = T.c_string() + if dimensions not in [1, 2, 3]: + raise ValueError("sycl accessors can only have dims 1, 2 or 3") + typestring = ( + f"sycl::accessor< {cpp_typestr}, {dimensions} > {'&' if reference else ''}" + ) + super().__init__(PsCustomType(typestring)) + + self._dim = dimensions + + @property + def required_includes(self) -> set[SfgHeaderInclude]: + return {SfgHeaderInclude("sycl/sycl.hpp", system_header=True)} + + def get_extraction(self) -> IFieldExtraction: + accessor = self + + class Extraction(IFieldExtraction): + def ptr(self) -> AugExpr: + return AugExpr.format( + "{}.get_multi_ptr<sycl::access::decorated::no>().get()", + accessor, + ) + + def size(self, coordinate: int) -> AugExpr | None: + if coordinate > accessor._dim: + return None + else: + return AugExpr.format( + "{}.get_range().get({})", accessor, coordinate + ) + + def stride(self, coordinate: int) -> AugExpr | None: + if coordinate > accessor._dim: + return None + else: + if coordinate == accessor._dim - 1: + return AugExpr.format("1") + else: + exprs = [] + args = [] + for d in range(coordinate + 1, accessor._dim): + args.extend([accessor, d]) + exprs.append("{}.get_range().get({})") + expr = " * ".join(exprs) + return AugExpr.format(expr, *args) + + return Extraction() + + +def sycl_accessor_ref(field: Field): + """Creates a `sycl::accessor &` for a given pystencils field.""" + + return SyclAccessor( + field.dtype, + field.spatial_dimensions, + reference=True, + ).var(field.name)