diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index bcb3a53f8e92461f71acd7da8700be3abfad97ae..f3ee646a59432b3f9aed1d5fc1de731b7139babd 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -219,6 +219,11 @@ class KernelCreationContext: """Return an iterable of all symbols listed in the symbol table.""" return self._symbols.values() + @property + def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]: + """Return a dictionary holding symbols and their reduction property.""" + return self._symbols_with_reduction + # Fields and Arrays @property diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 78e721f3850e0075a8079131b84ae558abb50062..6d72e1550b76f4122db060cf2af6492283bdcefe 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -10,6 +10,8 @@ from ..ast import PsAstNode from ..ast.structural import PsBlock, PsLoop, PsPragma from ..ast.expressions import PsExpression +from ...types import PsScalarType + if TYPE_CHECKING: from ...codegen.config import OpenMpConfig @@ -110,6 +112,13 @@ class AddOpenMP: pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" pragma_text += f" for schedule({omp_params.schedule})" + if bool(ctx.symbols_with_reduction): + for symbol, reduction in ctx.symbols_with_reduction.items(): + if isinstance(symbol.dtype, PsScalarType): + pragma_text += f" reduction({reduction.op}: {symbol.name})" + else: + NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.") + if omp_params.num_threads is not None: pragma_text += f" num_threads({str(omp_params.num_threads)})"