Skip to content
Snippets Groups Projects
Commit a16969bf authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Add omp reduction clauses for reduced symbols

parent 4ae330dc
1 merge request!438Reduction Support
......@@ -219,6 +219,11 @@ class KernelCreationContext:
"""Return an iterable of all symbols listed in the symbol table."""
return self._symbols.values()
@property
def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]:
"""Return a dictionary holding symbols and their reduction property."""
return self._symbols_with_reduction
# Fields and Arrays
@property
......
......@@ -10,6 +10,8 @@ from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression
from ...types import PsScalarType
if TYPE_CHECKING:
from ...codegen.config import OpenMpConfig
......@@ -110,6 +112,13 @@ class AddOpenMP:
pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
pragma_text += f" for schedule({omp_params.schedule})"
if bool(ctx.symbols_with_reduction):
for symbol, reduction in ctx.symbols_with_reduction.items():
if isinstance(symbol.dtype, PsScalarType):
pragma_text += f" reduction({reduction.op}: {symbol.name})"
else:
NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.")
if omp_params.num_threads is not None:
pragma_text += f" num_threads({str(omp_params.num_threads)})"
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment