diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index af59f4f82b1451928b3e3bbdb6835d4ce92f33c5..4ee499126fe795f95284d9c40d02513f53333054 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -5,7 +5,6 @@ import re from pystencils.types import PsType, PsCustomType from pystencils.enums import Target -from pystencils.backend.kernelfunction import KernelParameter from ..exceptions import SfgException from ..context import SfgContext @@ -15,8 +14,7 @@ from ..composer import ( SfgComposer, SfgComposerMixIn, ) -from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude -from ..ir.source_components import SfgSymbolLike +from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, @@ -75,7 +73,7 @@ class SyclHandler(AugExpr): id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") - def filter_id(param: SfgSymbolLike[KernelParameter]) -> bool: + def filter_id(param: SfgKernelParamVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -119,7 +117,7 @@ class SyclGroup(AugExpr): id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") - def filter_id(param: SfgSymbolLike[KernelParameter]) -> bool: + def filter_id(param: SfgKernelParamVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index 1ae1749367527472aaf5ba77ddf09a8081ed578b..9ab366cbd65373639e2123b9d896ab2839225166 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -19,7 +19,7 @@ from .source_components import ( SfgEmptyLines, SfgKernelNamespace, SfgKernelHandle, - SfgSymbolLike, + SfgKernelParamVar, SfgFunction, SfgVisibility, SfgClassKeyword, @@ -50,7 +50,7 @@ __all__ = [ "SfgEmptyLines", "SfgKernelNamespace", "SfgKernelHandle", - "SfgSymbolLike", + "SfgKernelParamVar", "SfgFunction", "SfgVisibility", "SfgClassKeyword", diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 851a981862cb900442d92624bbaef0dbece68f89..c33ec7ab282432d6f462c3316032bb8842cd7652 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -8,18 +8,14 @@ from abc import ABC, abstractmethod import sympy as sp -from pystencils import Field, TypedSymbol +from pystencils import Field from pystencils.types import deconstify -from pystencils.backend.kernelfunction import ( - FieldPointerParam, - FieldShapeParam, - FieldStrideParam, -) +from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements -from ..ir.source_components import SfgSymbolLike +from ..ir.source_components import SfgKernelParamVar from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector if TYPE_CHECKING: @@ -252,43 +248,38 @@ class SfgDeferredFieldMapping(SfgDeferredNode): else extraction.get_extraction() ) - # type: ignore def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: # Find field pointer - ptr: SfgSymbolLike[FieldPointerParam] | None = None - shape: list[SfgSymbolLike[FieldShapeParam] | int | None] = [None] * len( - self._field.shape - ) - strides: list[SfgSymbolLike[FieldStrideParam] | int | None] = [None] * len( + ptr: SfgKernelParamVar | None = None + shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape) + strides: list[SfgKernelParamVar | str | None] = [None] * len( self._field.strides ) for param in ppc.live_variables: - # idk why, but mypy does not understand these pattern matches - match param: - case SfgSymbolLike(FieldPointerParam(_, _, field)) if field == self._field: # type: ignore - ptr = param - case SfgSymbolLike( - FieldShapeParam(_, _, field, coord) # type: ignore - ) if field == self._field: # type: ignore - shape[coord] = param # type: ignore - case SfgSymbolLike( - FieldStrideParam(_, _, field, coord) # type: ignore - ) if field == self._field: # type: ignore - strides[coord] = param # type: ignore - - # Find constant sizes + if isinstance(param, SfgKernelParamVar): + for prop in param.wrapped.properties: + match prop: + case FieldBasePtr(field) if field == self._field: + ptr = param + case FieldShape(field, coord) if field == self._field: # type: ignore + shape[coord] = param # type: ignore + case FieldStride(field, coord) if field == self._field: # type: ignore + strides[coord] = param # type: ignore + + # Find constant or otherwise determined sizes for coord, s in enumerate(self._field.shape): - if not isinstance(s, TypedSymbol): - shape[coord] = s + if shape[coord] is None: + shape[coord] = str(s) - # Find constant strides + # Find constant or otherwise determined strides for coord, s in enumerate(self._field.strides): - if not isinstance(s, TypedSymbol): - strides[coord] = s + if strides[coord] is None: + strides[coord] = str(s) # Now we have all the symbols, start extracting them nodes = [] + done: set[SfgKernelParamVar] = set() if ptr is not None: expr = self._extraction.ptr() @@ -298,7 +289,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ) ) - def get_shape(coord, symb: SfgSymbolLike | int): + def get_shape(coord, symb: SfgKernelParamVar | str): expr = self._extraction.size(coord) if expr is None: @@ -306,14 +297,15 @@ class SfgDeferredFieldMapping(SfgDeferredNode): f"Cannot extract shape in coordinate {coord} from {self._extraction}" ) - if isinstance(symb, SfgSymbolLike): + if isinstance(symb, SfgKernelParamVar) and symb not in done: + done.add(symb) return SfgStatements( f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) - def get_stride(coord, symb: SfgSymbolLike | int): + def get_stride(coord, symb: SfgKernelParamVar | str): expr = self._extraction.stride(coord) if expr is None: @@ -321,7 +313,8 @@ class SfgDeferredFieldMapping(SfgDeferredNode): f"Cannot extract stride in coordinate {coord} from {self._extraction}" ) - if isinstance(symb, SfgSymbolLike): + if isinstance(symb, SfgKernelParamVar) and symb not in done: + done.add(symb) return SfgStatements( f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends ) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 8a4e90967d74aabc4f50529c7ff0e51c76a3d869..4398938327242e8e1c1fb2cfc7787c72fa0d3138 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC from enum import Enum, auto -from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic +from typing import TYPE_CHECKING, Sequence, Generator, TypeVar from dataclasses import replace from itertools import chain @@ -10,7 +10,6 @@ from pystencils import CreateKernelConfig, create_kernel, Field from pystencils.backend.kernelfunction import ( KernelFunction, KernelParameter, - FieldParameter, ) from pystencils.types import PsType, PsCustomType @@ -162,14 +161,14 @@ class SfgKernelHandle: self._ctx = ctx self._name = name self._namespace = namespace - self._parameters = [SfgSymbolLike(p) for p in parameters] + self._parameters = [SfgKernelParamVar(p) for p in parameters] - self._scalar_params: set[SfgSymbolLike] = set() + self._scalar_params: set[SfgKernelParamVar] = set() self._fields: set[Field] = set() for param in self._parameters: - if isinstance(param.wrapped, FieldParameter): - self._fields.add(param.wrapped.field) + if param.wrapped.is_field_parameter: + self._fields |= set(param.wrapped.fields) else: self._scalar_params.add(param) @@ -190,7 +189,7 @@ class SfgKernelHandle: return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}" @property - def parameters(self) -> Sequence[SfgSymbolLike]: + def parameters(self) -> Sequence[SfgKernelParamVar]: return self._parameters @property @@ -208,17 +207,17 @@ class SfgKernelHandle: SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter) -class SfgSymbolLike(SfgVar, Generic[SymbolLike_T]): +class SfgKernelParamVar(SfgVar): __match_args__ = ("wrapped",) """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" - def __init__(self, param: SymbolLike_T): + def __init__(self, param: KernelParameter): self._param = param super().__init__(param.name, param.dtype) @property - def wrapped(self) -> SymbolLike_T: + def wrapped(self) -> KernelParameter: return self._param def _args(self): diff --git a/tests/generator_scripts/expected/Variables.h b/tests/generator_scripts/expected/Variables.h index 96c16d7b306fe7594db0caa02e9168b2f5c86fc6..b6f2ce700b80d6e8c062208dc5239c67d2c728fc 100644 --- a/tests/generator_scripts/expected/Variables.h +++ b/tests/generator_scripts/expected/Variables.h @@ -9,5 +9,5 @@ private: float alpha; public: Scale(float alpha) : alpha{ alpha } {} - void operator() (float *const _data_f, float *const _data_g); + void operator() (float *RESTRICT const _data_f, float *RESTRICT const _data_g); }; diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index e144024fa831bf5c31c17b7efc93d82857654671..3030e1294d784c88c11e5770c9af0bfe00c21333 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -1,10 +1,13 @@ import sympy as sp -from pystencils import fields, kernel, TypedSymbol +from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type +from pystencils.types import PsCustomType from pystencilssfg import SfgContext, SfgComposer from pystencilssfg.composer import make_sequence -from pystencilssfg.ir import SfgStatements +from pystencilssfg.lang import IFieldExtraction, AugExpr + +from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing @@ -75,3 +78,101 @@ def test_find_sympy_symbols(): assert isinstance(call_tree.children[1], SfgStatements) assert call_tree.children[1].code_string == "const double y = x / a;" + + +class TestFieldExtraction(IFieldExtraction): + def __init__(self, name: str): + self.obj = AugExpr(PsCustomType("MyField")).var(name) + + def ptr(self) -> AugExpr: + return AugExpr.format("{}.ptr()", self.obj) + + def size(self, coordinate: int) -> AugExpr | None: + return AugExpr.format("{}.size({})", self.obj, coordinate) + + def stride(self, coordinate: int) -> AugExpr | None: + return AugExpr.format("{}.stride({})", self.obj, coordinate) + + +def test_field_extraction(): + sx, sy, tx, ty = [ + TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty") + ] + f = Field("f", FieldType.GENERIC, "double", (1, 0), (sx, sy), (tx, ty)) + + @kernel + def set_constant(): + f.center @= 13.2 + + sfg = SfgComposer(SfgContext()) + + khandle = sfg.kernels.create(set_constant) + + extraction = TestFieldExtraction("f") + call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle)) + + pp = CallTreePostProcessing() + free_vars = pp.get_live_variables(call_tree) + assert free_vars == {extraction.obj.as_variable()} + + lines = [ + r"double * RESTRICT const _data_f { f.ptr() };", + r"const int64_t sx { f.size(0) };", + r"const int64_t sy { f.size(1) };", + r"const int64_t tx { f.stride(0) };", + r"const int64_t ty { f.stride(1) };", + ] + + assert isinstance(call_tree.children[0], SfgSequence) + for line, stmt in zip(lines, call_tree.children[0].children, strict=True): + assert isinstance(stmt, SfgStatements) + assert stmt.code_string == line + + +def test_duplicate_field_shapes(): + N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")] + f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) + g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) + + @kernel + def set_constant(): + f.center @= g.center(0) + + sfg = SfgComposer(SfgContext()) + + khandle = sfg.kernels.create(set_constant) + + call_tree = make_sequence( + sfg.map_field(g, TestFieldExtraction("g")), + sfg.map_field(f, TestFieldExtraction("f")), + sfg.call(khandle), + ) + + pp = CallTreePostProcessing() + _ = pp.get_live_variables(call_tree) + + lines_g = [ + r"double * RESTRICT const _data_g { g.ptr() };", + r"/* g.size(0) == N */", + r"/* g.size(1) == N */", + r"/* g.stride(0) == tx */", + r"/* g.stride(1) == ty */", + ] + + assert isinstance(call_tree.children[0], SfgSequence) + for line, stmt in zip(lines_g, call_tree.children[0].children, strict=True): + assert isinstance(stmt, SfgStatements) + assert stmt.code_string == line + + lines_f = [ + r"double * RESTRICT const _data_f { f.ptr() };", + r"const int64_t N { f.size(0) };", + r"/* f.size(1) == N */", + r"const int64_t tx { f.stride(0) };", + r"const int64_t ty { f.stride(1) };", + ] + + assert isinstance(call_tree.children[1], SfgSequence) + for line, stmt in zip(lines_f, call_tree.children[1].children, strict=True): + assert isinstance(stmt, SfgStatements) + assert stmt.code_string == line