diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 61016e14f11a536444c798441ba8be516d97a167..3d3b7846a84bb9c477b38e90839d7f67fe12933c 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -6,8 +6,14 @@ from . import fd from . import stencil as stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields +from .types import create_type from .cache import clear_cache -from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig +from .config import ( + CreateKernelConfig, + CpuOptimConfig, + VectorizationConfig, + OpenMpConfig, +) from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel from .backend.kernelfunction import KernelFunction @@ -34,10 +40,12 @@ __all__ = [ "fields", "DEFAULTS", "TypedSymbol", + "create_type", "make_slice", "CreateKernelConfig", "CpuOptimConfig", "VectorizationConfig", + "OpenMpConfig", "create_kernel", "KernelFunction", "Target", diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 040c6167827dd48c97bc289c0959d937bbbefb38..15ee0680edb0b5b8197aec2545d182f90eb6c71a 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -5,7 +5,7 @@ from .structural import ( PsAssignment, PsAstNode, PsBlock, - PsComment, + PsEmptyLeafMixIn, PsConditional, PsDeclaration, PsExpression, @@ -63,7 +63,7 @@ class UndefinedSymbolsCollector: undefined_vars |= self(branch_false) return undefined_vars - case PsComment(): + case PsEmptyLeafMixIn(): return set() case unknown: @@ -92,11 +92,11 @@ class UndefinedSymbolsCollector: case ( PsAssignment() | PsBlock() - | PsComment() | PsConditional() | PsExpression() | PsLoop() | PsStatement() + | PsEmptyLeafMixIn() ): return set() diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7bcf62b973d8ace8e9ad9847ae165c398f1cbb0e..4063f7b539ab387d1f950a75e735f3c6201b5ef2 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def __repr__(self) -> str: return f"PsConstantExpr({repr(self._constant)})" - + class PsLiteralExpr(PsLeafMixIn, PsExpression): __match_args__ = ("literal",) @@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): def clone(self) -> PsLiteralExpr: return PsLiteralExpr(self._literal) - + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsLiteralExpr): return False diff --git a/src/pystencils/backend/ast/iteration.py b/src/pystencils/backend/ast/iteration.py index 6c1c406ed4602ed98416816886686fe29324975f..cc666c72257c5a434c0b312c9e8d8e5ea9dc028f 100644 --- a/src/pystencils/backend/ast/iteration.py +++ b/src/pystencils/backend/ast/iteration.py @@ -4,32 +4,32 @@ from .structural import PsAstNode def dfs_preorder( - node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True + node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True ) -> Generator[PsAstNode, None, None]: """Pre-Order depth-first traversal of an abstract syntax tree. Args: node: The tree's root node - yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True + filter_pred: Filter predicate; a node is only returned to the caller if `yield_pred(node)` returns True """ - if yield_pred(node): + if filter_pred(node): yield node for c in node.children: - yield from dfs_preorder(c, yield_pred) + yield from dfs_preorder(c, filter_pred) def dfs_postorder( - node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True + node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True ) -> Generator[PsAstNode, None, None]: """Post-Order depth-first traversal of an abstract syntax tree. Args: node: The tree's root node - yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True + filter_pred: Filter predicate; a node is only returned to the caller if `yield_pred(node)` returns True """ for c in node.children: - yield from dfs_postorder(c, yield_pred) + yield from dfs_postorder(c, filter_pred) - if yield_pred(node): + if filter_pred(node): yield node diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 47342cfedcb7c748f8c460ae99512ec64800ce61..cd3aae30d35061ab6c15c338a735aaecca83a141 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -307,7 +307,42 @@ class PsConditional(PsAstNode): assert False, "unreachable code" -class PsComment(PsLeafMixIn, PsAstNode): +class PsEmptyLeafMixIn: + """Mix-in marking AST leaves that can be treated as empty by the code generator, + such as comments and preprocessor directives.""" + + pass + + +class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): + """A C/C++ preprocessor pragma. + + Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. + + Args: + text: The pragma's text, without the ``#pragma ``. + """ + + __match_args__ = ("text",) + + def __init__(self, text: str) -> None: + self._text = text + + @property + def text(self) -> str: + return self._text + + def clone(self) -> PsPragma: + return PsPragma(self.text) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsPragma): + return False + + return self._text == other._text + + +class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): __match_args__ = ("lines",) def __init__(self, text: str) -> None: diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index f3d56c6c4c20e5969ee10d08ee42b6803a2e0b1c..e8fc2a662b2f49c43fe51d021915b9e44cc59ad3 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -10,6 +10,7 @@ from .ast.structural import ( PsLoop, PsConditional, PsComment, + PsPragma, ) from .ast.expressions import ( @@ -235,6 +236,9 @@ class CAstPrinter: lines_list[-1] = lines_list[-1] + " */" return pc.indent("\n".join(lines_list)) + case PsPragma(text): + return pc.indent("#pragma " + text) + case PsSymbolExpr(symbol): return symbol.name @@ -246,7 +250,7 @@ class CAstPrinter: ) return dtype.create_literal(constant.value) - + case PsLiteralExpr(lit): return lit.text diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py index b0156c7e8ce0b9cac2c6f3be9f60bbeef41e1c51..29b133ff164e856783f14eb83357c8382db9ba5d 100644 --- a/src/pystencils/backend/kernelcreation/cpu_optimization.py +++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py @@ -3,10 +3,9 @@ from typing import cast from .context import KernelCreationContext from ..platforms import GenericCpu -from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations from ..ast.structural import PsBlock -from ...config import CpuOptimConfig +from ...config import CpuOptimConfig, OpenMpConfig def optimize_cpu( @@ -16,6 +15,7 @@ def optimize_cpu( cfg: CpuOptimConfig | None, ) -> PsBlock: """Carry out CPU-specific optimizations according to the given configuration.""" + from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations canonicalize = CanonicalizeSymbols(ctx, True) kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) @@ -32,8 +32,12 @@ def optimize_cpu( if cfg.vectorize is not False: raise NotImplementedError("Vectorization not implemented yet") - if cfg.openmp: - raise NotImplementedError("OpenMP not implemented yet") + if cfg.openmp is not False: + from ..transformations import AddOpenMP + + params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig() + add_omp = AddOpenMP(ctx, params) + kernel_ast = cast(PsBlock, add_omp(kernel_ast)) if cfg.use_cacheline_zeroing: raise NotImplementedError("CL-zeroing not implemented yet") diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index dbec20235f0a37cfab763771e2e2fbed05a3c196..06e34d4e36219366a25713d28ae758b61fb3d0d6 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,7 +22,7 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, - PsComment, + PsEmptyLeafMixIn, ) from ..ast.expressions import ( PsArrayAccess, @@ -159,7 +159,7 @@ class TypeContext: f" Constant type: {c.dtype}\n" f" Target type: {self._target_type}" ) - + case PsLiteralExpr(lit): if not self._compatible(lit.dtype): raise TypificationError( @@ -336,7 +336,7 @@ class Typifier: self.visit(body) - case PsComment(): + case PsEmptyLeafMixIn(): pass case _: diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py index dc7504f520f8950b46df76b0359aaad371244b19..dc254da0e340d518929d6eecb483defcdffbe185 100644 --- a/src/pystencils/backend/literals.py +++ b/src/pystencils/backend/literals.py @@ -4,7 +4,7 @@ from ..types import PsType, constify class PsLiteral: """Representation of literal code. - + Instances of this class represent code literals inside the AST. These literals are not to be confused with C literals; the name `Literal` refers to the fact that the code generator takes them "literally", printing them as they are. @@ -22,22 +22,22 @@ class PsLiteral: @property def text(self) -> str: return self._text - + @property def dtype(self) -> PsType: return self._dtype - + def __str__(self) -> str: return f"{self._text}: {self._dtype}" - + def __repr__(self) -> str: return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" - + def __eq__(self, other: object) -> bool: if not isinstance(other, PsLiteral): return False - + return self._text == other._text and self._dtype == other._dtype - + def __hash__(self) -> int: return hash((PsLiteral, self._text, self._dtype)) diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 518c402d27e8828cc129c06a20bd10e2ba3d3168..88ad9348f09685258d2aecb5fca66fcfe609173b 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -60,6 +60,11 @@ Loop Reshaping Transformations .. autoclass:: ReshapeLoops :members: +.. autoclass:: InsertPragmasAtLoops + :members: + +.. autoclass:: AddOpenMP + :members: Code Lowering and Materialization --------------------------------- @@ -78,6 +83,7 @@ from .eliminate_constants import EliminateConstants from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops +from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics @@ -89,6 +95,9 @@ __all__ = [ "EliminateBranches", "HoistLoopInvariantDeclarations", "ReshapeLoops", + "InsertPragmasAtLoops", + "LoopPragma", + "AddOpenMP", "EraseAnonymousStructTypes", "SelectFunctions", "MaterializeVectorIntrinsics", diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py new file mode 100644 index 0000000000000000000000000000000000000000..c7015ccb62042e6beaf131e68485f7c8186b2a2a --- /dev/null +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -0,0 +1,118 @@ +from dataclasses import dataclass + +from typing import Sequence +from collections import defaultdict + +from ..kernelcreation import KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsBlock, PsLoop, PsPragma +from ..ast.expressions import PsExpression + +from ...config import OpenMpConfig + +__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"] + + +@dataclass +class LoopPragma: + """A pragma that should be prepended to loops at a certain nesting depth.""" + + text: str + """The pragma text, without the ``#pragma ``""" + + loop_nesting_depth: int + """Nesting depth of the loops the pragma should be added to. ``-1`` indicates the innermost loops.""" + + def __post_init__(self): + if self.loop_nesting_depth < -1: + raise ValueError("Loop nesting depth must be nonnegative or -1.") + + +@dataclass +class Nesting: + depth: int + has_inner_loops: bool = False + + +class InsertPragmasAtLoops: + """Insert pragmas before loops in a loop nest. + + This transformation augments the AST with pragma directives which are prepended to loops. + The directives are annotated with the nesting depth of the loops they should be added to, + where ``-1`` indicates the innermost loop. + + The relative order of pragmas with the (exact) same nesting depth is preserved; + however, no guarantees are given about the relative order of pragmas inserted at ``-1`` + and at the actual depth of the innermost loop. + """ + + def __init__( + self, ctx: KernelCreationContext, insertions: Sequence[LoopPragma] + ) -> None: + self._ctx = ctx + self._insertions: dict[int, list[LoopPragma]] = defaultdict(list) + for ins in insertions: + self._insertions[ins.loop_nesting_depth].append(ins) + + def __call__(self, node: PsAstNode) -> PsAstNode: + is_loop = isinstance(node, PsLoop) + if is_loop: + node = PsBlock([node]) + + self.visit(node, Nesting(0)) + + if is_loop and len(node.children) == 1: + node = node.children[0] + + return node + + def visit(self, node: PsAstNode, nest: Nesting) -> None: + match node: + case PsExpression(): + return + + case PsBlock(children): + new_children: list[PsAstNode] = [] + for c in children: + if isinstance(c, PsLoop): + nest.has_inner_loops = True + inner_nest = Nesting(nest.depth + 1) + self.visit(c.body, inner_nest) + + if not inner_nest.has_inner_loops: + # c is the innermost loop + for pragma in self._insertions[-1]: + new_children.append(PsPragma(pragma.text)) + + for pragma in self._insertions[nest.depth]: + new_children.append(PsPragma(pragma.text)) + + new_children.append(c) + node.children = new_children + + case other: + for c in other.children: + self.visit(c, nest) + + +class AddOpenMP: + """Apply OpenMP directives to loop nests. + + This transformation augments the AST with OpenMP pragmas according to the given + `OpenMpParams` configuration. + """ + + def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None: + pragma_text = "omp" + pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" + pragma_text += f" for schedule({omp_params.schedule})" + + if omp_params.collapse > 0: + pragma_text += f" collapse({str(omp_params.collapse)})" + + self._insert_pragmas = InsertPragmasAtLoops( + ctx, [LoopPragma(pragma_text, omp_params.nesting_depth)] + ) + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self._insert_pragmas(node) diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py index 7c040d30471c9dbe413d287142199050c7c24a37..b21fd115f98645ff4c8dfb2dd3f72c252282fcf2 100644 --- a/src/pystencils/backend/transformations/canonical_clone.py +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -12,6 +12,7 @@ from ..ast.structural import ( PsDeclaration, PsAssignment, PsComment, + PsPragma, PsStatement, ) from ..ast.expressions import PsExpression, PsSymbolExpr @@ -73,7 +74,7 @@ class CanonicalClone: ), ) - case PsComment(): + case PsComment() | PsPragma(): return cast(Node_T, node.clone()) case PsDeclaration(lhs, rhs): diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 2d3327104dd6d5989b58a5dfc646f125e7bf94ad..7c49c4c37c0257adde151c3c32680faa50e9a36a 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -15,19 +15,39 @@ from .types import PsIntegerType, PsNumericType, PsIeeeFloatType from .defaults import DEFAULTS +@dataclass +class OpenMpConfig: + """Parameters controlling kernel parallelization using OpenMP.""" + + nesting_depth: int = 0 + """Nesting depth of the loop that should be parallelized. Must be a nonnegative number.""" + + collapse: int = 0 + """Argument to the OpenMP ``collapse`` clause""" + + schedule: str = "static" + """Argument to the OpenMP ``schedule`` clause""" + + omit_parallel_construct: bool = False + """If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``. + + Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region. + """ + + @dataclass class CpuOptimConfig: """Configuration for the CPU optimizer. - + If any flag in this configuration is set to a value not supported by the CPU specified in `CreateKernelConfig.target`, an error will be raised. """ - - openmp: bool = False + + openmp: bool | OpenMpConfig = False """Enable OpenMP parallelization. - If set to `True`, the kernel will be parallelized using OpenMP according to the OpenMP settings - given in this configuration. + If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpParams`. + To customize OpenMP parallelization, pass an instance of `OpenMpParams` instead. """ vectorize: bool | VectorizationConfig = False @@ -58,7 +78,7 @@ class CpuOptimConfig: @dataclass class VectorizationConfig: """Configuration for the auto-vectorizer. - + If any flag in this configuration is set to a value not supported by the CPU specified in `CreateKernelConfig.target`, an error will be raised. """ @@ -182,19 +202,27 @@ class CreateKernelConfig: raise PsOptionsError( "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" ) - + # Check optim if self.cpu_optim is not None: if not self.target.is_cpu(): - raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}") - - if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): - raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}") + raise PsOptionsError( + f"`cpu_optim` cannot be set for non-CPU target {self.target}" + ) + + if ( + self.cpu_optim.vectorize is not False + and not self.target.is_vector_cpu() + ): + raise PsOptionsError( + f"Cannot enable auto-vectorization for non-vector CPU target {self.target}" + ) # Infer JIT if self.jit is None: if self.target.is_cpu(): from .backend.jit import LegacyCpuJit + self.jit = LegacyCpuJit() else: raise NotImplementedError( diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 0f6941cf52ca1c3745b5af9a1f0478f269fc8e4a..66c2a0d6c16e291ba5f6315478406668e7e91069 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -104,6 +104,7 @@ def create_kernel( # Target-Specific optimizations if config.target.is_cpu(): from .backend.kernelcreation import optimize_cpu + kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) erase_anons = EraseAnonymousStructTypes(ctx) diff --git a/tests/nbackend/kernelcreation/test_openmp.py b/tests/nbackend/kernelcreation/test_openmp.py new file mode 100644 index 0000000000000000000000000000000000000000..d7be8eb98cd29bea370bc6279013ef973e621370 --- /dev/null +++ b/tests/nbackend/kernelcreation/test_openmp.py @@ -0,0 +1,61 @@ +import pytest +from pystencils import ( + fields, + Assignment, + create_kernel, + CreateKernelConfig, + CpuOptimConfig, + OpenMpConfig, + Target, +) + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsLoop, PsPragma + + +@pytest.mark.parametrize("nesting_depth", range(3)) +@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"]) +@pytest.mark.parametrize("collapse", range(3)) +@pytest.mark.parametrize("omit_parallel_construct", range(3)) +def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct): + f, g = fields("f, g: [3D]") + asm = Assignment(f.center(0), g.center(0)) + + omp = OpenMpConfig( + nesting_depth=nesting_depth, + schedule=schedule, + collapse=collapse, + omit_parallel_construct=omit_parallel_construct, + ) + gen_config = CreateKernelConfig( + target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp) + ) + + kernel = create_kernel(asm, gen_config) + ast = kernel.body + + def find_omp_pragma(ast) -> PsPragma: + num_loops = 0 + generator = dfs_preorder(ast) + for node in generator: + match node: + case PsLoop(): + num_loops += 1 + case PsPragma(): + loop = next(generator) + assert isinstance(loop, PsLoop) + assert num_loops == nesting_depth + return node + + pytest.fail("No OpenMP pragma found") + + pragma = find_omp_pragma(ast) + tokens = set(pragma.text.split()) + + expected_tokens = {"omp", "for", f"schedule({omp.schedule})"} + if not omp.omit_parallel_construct: + expected_tokens.add("parallel") + if omp.collapse > 0: + expected_tokens.add(f"collapse({omp.collapse})") + + assert tokens == expected_tokens diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index fb2dd0e04081050138be1c8efc0210f81c999707..09c63a5572f7873648712dbd3279d16f3c342458 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -12,6 +12,7 @@ from pystencils.backend.ast.structural import ( PsBlock, PsConditional, PsComment, + PsPragma, PsLoop, ) from pystencils.types.quick import Fp, Ptr @@ -44,6 +45,7 @@ def test_cloning(): PsConditional( y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) ), + PsPragma("omp parallel for"), PsLoop( x, y, @@ -54,6 +56,7 @@ def test_cloning(): PsComment("Loop body"), PsAssignment(x, y), PsAssignment(x, y), + PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( PsDeref(PsCast(Ptr(Fp(32)), z)) + PsSubscript(z, one + one + one) diff --git a/tests/nbackend/transformations/test_add_pragmas.py b/tests/nbackend/transformations/test_add_pragmas.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8dd1ded148697dbe7acde3648a292dc5a7fcac --- /dev/null +++ b/tests/nbackend/transformations/test_add_pragmas.py @@ -0,0 +1,54 @@ +import sympy as sp +from itertools import product + +from pystencils import make_slice, fields, Assignment +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop +from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma + +def test_insert_pragmas(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + f, g = fields("f, g: [3D]") + ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[:, :, :], archetype_field=f + ) + ctx.set_iteration_space(ispace) + + stencil = list(product([-1, 0, 1], [-1, 0, 1], [-1, 0, 1])) + loop_body = PsBlock([ + factory.parse_sympy(Assignment(f.center(0), sum(g.neighbors(stencil)))) + ]) + loops = factory.loops_from_ispace(ispace, loop_body) + + pragmas = ( + LoopPragma("omp parallel for", 0), + LoopPragma("some nonsense pragma", 1), + LoopPragma("omp simd", -1), + ) + add_pragmas = InsertPragmasAtLoops(ctx, pragmas) + ast = add_pragmas(loops) + + assert isinstance(ast, PsBlock) + + first_pragma = ast.statements[0] + assert isinstance(first_pragma, PsPragma) + assert first_pragma.text == pragmas[0].text + + assert ast.statements[1] == loops + second_pragma = loops.body.statements[0] + assert isinstance(second_pragma, PsPragma) + assert second_pragma.text == pragmas[1].text + + second_loop = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))[1] + assert isinstance(second_loop, PsLoop) + third_pragma = second_loop.body.statements[0] + assert isinstance(third_pragma, PsPragma) + assert third_pragma.text == pragmas[2].text