Skip to content
Snippets Groups Projects
Commit 02965644 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'fhennig/pragmas' into 'backend-rework'

Pragmas and OpenMP Support

See merge request !383
parents 7790ec6f 990b2eb9
Branches
Tags
1 merge request!383Pragmas and OpenMP Support
Pipeline #66700 passed with stages
in 5 minutes and 16 seconds
Showing
with 368 additions and 42 deletions
......@@ -6,8 +6,14 @@ from . import fd
from . import stencil as stencil
from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields
from .types import create_type
from .cache import clear_cache
from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig
from .config import (
CreateKernelConfig,
CpuOptimConfig,
VectorizationConfig,
OpenMpConfig,
)
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction
......@@ -34,10 +40,12 @@ __all__ = [
"fields",
"DEFAULTS",
"TypedSymbol",
"create_type",
"make_slice",
"CreateKernelConfig",
"CpuOptimConfig",
"VectorizationConfig",
"OpenMpConfig",
"create_kernel",
"KernelFunction",
"Target",
......
......@@ -5,7 +5,7 @@ from .structural import (
PsAssignment,
PsAstNode,
PsBlock,
PsComment,
PsEmptyLeafMixIn,
PsConditional,
PsDeclaration,
PsExpression,
......@@ -63,7 +63,7 @@ class UndefinedSymbolsCollector:
undefined_vars |= self(branch_false)
return undefined_vars
case PsComment():
case PsEmptyLeafMixIn():
return set()
case unknown:
......@@ -92,11 +92,11 @@ class UndefinedSymbolsCollector:
case (
PsAssignment()
| PsBlock()
| PsComment()
| PsConditional()
| PsExpression()
| PsLoop()
| PsStatement()
| PsEmptyLeafMixIn()
):
return set()
......
......@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def __repr__(self) -> str:
return f"PsConstantExpr({repr(self._constant)})"
class PsLiteralExpr(PsLeafMixIn, PsExpression):
__match_args__ = ("literal",)
......@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
def clone(self) -> PsLiteralExpr:
return PsLiteralExpr(self._literal)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsLiteralExpr):
return False
......
......@@ -4,32 +4,32 @@ from .structural import PsAstNode
def dfs_preorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True
node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
"""Pre-Order depth-first traversal of an abstract syntax tree.
Args:
node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True
filter_pred: Filter predicate; a node is only returned to the caller if `filter_pred(node)` returns True
"""
if yield_pred(node):
if filter_pred(node):
yield node
for c in node.children:
yield from dfs_preorder(c, yield_pred)
yield from dfs_preorder(c, filter_pred)
def dfs_postorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True
node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
"""Post-Order depth-first traversal of an abstract syntax tree.
Args:
node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True
filter_pred: Filter predicate; a node is only returned to the caller if `filter_pred(node)` returns True
"""
for c in node.children:
yield from dfs_postorder(c, yield_pred)
yield from dfs_postorder(c, filter_pred)
if yield_pred(node):
if filter_pred(node):
yield node
......@@ -307,7 +307,42 @@ class PsConditional(PsAstNode):
assert False, "unreachable code"
class PsComment(PsLeafMixIn, PsAstNode):
class PsEmptyLeafMixIn:
    """Marker mix-in for AST leaves that the code generator may regard as empty.

    Comments and preprocessor directives are examples of such leaves.
    """
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
    """Leaf node representing a C/C++ preprocessor pragma.

    For instance, ``PsPragma("omp parallel for")`` stands for ``#pragma omp parallel for``.

    Args:
        text: Text of the pragma, excluding the leading ``#pragma ``.
    """

    __match_args__ = ("text",)

    def __init__(self, text: str) -> None:
        self._text = text

    @property
    def text(self) -> str:
        """The pragma text, as passed to the constructor."""
        return self._text

    def clone(self) -> PsPragma:
        """Create a new pragma node carrying the same text."""
        duplicate = PsPragma(self.text)
        return duplicate

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Return True only if ``other`` is a `PsPragma` with identical text."""
        return isinstance(other, PsPragma) and self._text == other._text
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
__match_args__ = ("lines",)
def __init__(self, text: str) -> None:
......
......@@ -10,6 +10,7 @@ from .ast.structural import (
PsLoop,
PsConditional,
PsComment,
PsPragma,
)
from .ast.expressions import (
......@@ -235,6 +236,9 @@ class CAstPrinter:
lines_list[-1] = lines_list[-1] + " */"
return pc.indent("\n".join(lines_list))
case PsPragma(text):
return pc.indent("#pragma " + text)
case PsSymbolExpr(symbol):
return symbol.name
......@@ -246,7 +250,7 @@ class CAstPrinter:
)
return dtype.create_literal(constant.value)
case PsLiteralExpr(lit):
return lit.text
......
......@@ -3,10 +3,9 @@ from typing import cast
from .context import KernelCreationContext
from ..platforms import GenericCpu
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from ..ast.structural import PsBlock
from ...config import CpuOptimConfig
from ...config import CpuOptimConfig, OpenMpConfig
def optimize_cpu(
......@@ -16,6 +15,7 @@ def optimize_cpu(
cfg: CpuOptimConfig | None,
) -> PsBlock:
"""Carry out CPU-specific optimizations according to the given configuration."""
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
......@@ -32,8 +32,12 @@ def optimize_cpu(
if cfg.vectorize is not False:
raise NotImplementedError("Vectorization not implemented yet")
if cfg.openmp:
raise NotImplementedError("OpenMP not implemented yet")
if cfg.openmp is not False:
from ..transformations import AddOpenMP
params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig()
add_omp = AddOpenMP(ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
......
......@@ -22,7 +22,7 @@ from ..ast.structural import (
PsExpression,
PsAssignment,
PsDeclaration,
PsComment,
PsEmptyLeafMixIn,
)
from ..ast.expressions import (
PsArrayAccess,
......@@ -159,7 +159,7 @@ class TypeContext:
f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}"
)
case PsLiteralExpr(lit):
if not self._compatible(lit.dtype):
raise TypificationError(
......@@ -336,7 +336,7 @@ class Typifier:
self.visit(body)
case PsComment():
case PsEmptyLeafMixIn():
pass
case _:
......
......@@ -4,7 +4,7 @@ from ..types import PsType, constify
class PsLiteral:
"""Representation of literal code.
Instances of this class represent code literals inside the AST.
These literals are not to be confused with C literals; the name `Literal` refers to the fact that
the code generator takes them "literally", printing them as they are.
......@@ -22,22 +22,22 @@ class PsLiteral:
@property
def text(self) -> str:
return self._text
@property
def dtype(self) -> PsType:
return self._dtype
def __str__(self) -> str:
return f"{self._text}: {self._dtype}"
def __repr__(self) -> str:
return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLiteral):
return False
return self._text == other._text and self._dtype == other._dtype
def __hash__(self) -> int:
return hash((PsLiteral, self._text, self._dtype))
......@@ -60,6 +60,11 @@ Loop Reshaping Transformations
.. autoclass:: ReshapeLoops
:members:
.. autoclass:: InsertPragmasAtLoops
:members:
.. autoclass:: AddOpenMP
:members:
Code Lowering and Materialization
---------------------------------
......@@ -78,6 +83,7 @@ from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops
from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP
from .erase_anonymous_structs import EraseAnonymousStructTypes
from .select_functions import SelectFunctions
from .select_intrinsics import MaterializeVectorIntrinsics
......@@ -89,6 +95,9 @@ __all__ = [
"EliminateBranches",
"HoistLoopInvariantDeclarations",
"ReshapeLoops",
"InsertPragmasAtLoops",
"LoopPragma",
"AddOpenMP",
"EraseAnonymousStructTypes",
"SelectFunctions",
"MaterializeVectorIntrinsics",
......
from dataclasses import dataclass
from typing import Sequence
from collections import defaultdict
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression
from ...config import OpenMpConfig
# Names exported by this module.
__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"]
@dataclass
class LoopPragma:
    """Description of a pragma to be placed in front of loops at a given nesting depth."""

    text: str
    """Text of the pragma, excluding the leading ``#pragma ``"""

    loop_nesting_depth: int
    """Nesting depth of the target loops; ``-1`` selects the innermost loops."""

    def __post_init__(self):
        # Only -1 (innermost) and nonnegative depths are meaningful.
        depth = self.loop_nesting_depth
        if depth < -1:
            raise ValueError("Loop nesting depth must be nonnegative or -1.")
@dataclass
class Nesting:
    """Mutable bookkeeping record for one level of a loop nest during traversal."""

    # Nesting depth associated with this record
    depth: int
    # Switched to True once a loop has been encountered at this level
    has_inner_loops: bool = False
class InsertPragmasAtLoops:
"""Insert pragmas before loops in a loop nest.
This transformation augments the AST with pragma directives which are prepended to loops.
The directives are annotated with the nesting depth of the loops they should be added to,
where ``-1`` indicates the innermost loop.
The relative order of pragmas with the (exact) same nesting depth is preserved;
however, no guarantees are given about the relative order of pragmas inserted at ``-1``
and at the actual depth of the innermost loop.
"""
def __init__(
self, ctx: KernelCreationContext, insertions: Sequence[LoopPragma]
) -> None:
self._ctx = ctx
self._insertions: dict[int, list[LoopPragma]] = defaultdict(list)
for ins in insertions:
self._insertions[ins.loop_nesting_depth].append(ins)
def __call__(self, node: PsAstNode) -> PsAstNode:
is_loop = isinstance(node, PsLoop)
if is_loop:
node = PsBlock([node])
self.visit(node, Nesting(0))
if is_loop and len(node.children) == 1:
node = node.children[0]
return node
def visit(self, node: PsAstNode, nest: Nesting) -> None:
match node:
case PsExpression():
return
case PsBlock(children):
new_children: list[PsAstNode] = []
for c in children:
if isinstance(c, PsLoop):
nest.has_inner_loops = True
inner_nest = Nesting(nest.depth + 1)
self.visit(c.body, inner_nest)
if not inner_nest.has_inner_loops:
# c is the innermost loop
for pragma in self._insertions[-1]:
new_children.append(PsPragma(pragma.text))
for pragma in self._insertions[nest.depth]:
new_children.append(PsPragma(pragma.text))
new_children.append(c)
node.children = new_children
case other:
for c in other.children:
self.visit(c, nest)
class AddOpenMP:
    """Annotate loop nests with an OpenMP directive.

    The pragma text is assembled from the given `OpenMpConfig`: ``parallel`` is
    left out if ``omit_parallel_construct`` is set, the ``schedule`` clause is always
    emitted, and a ``collapse`` clause is appended only for ``collapse > 0``.
    The resulting pragma is inserted at loops of depth ``omp_params.nesting_depth``.
    """

    def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None:
        parts = ["omp"]
        if not omp_params.omit_parallel_construct:
            parts.append("parallel")
        parts.append("for schedule(" + str(omp_params.schedule) + ")")
        if omp_params.collapse > 0:
            parts.append("collapse(" + str(omp_params.collapse) + ")")

        directive = LoopPragma(" ".join(parts), omp_params.nesting_depth)
        self._insert_pragmas = InsertPragmasAtLoops(ctx, [directive])

    def __call__(self, node: PsAstNode) -> PsAstNode:
        """Apply the pragma insertion to ``node`` and return the resulting tree."""
        return self._insert_pragmas(node)
......@@ -12,6 +12,7 @@ from ..ast.structural import (
PsDeclaration,
PsAssignment,
PsComment,
PsPragma,
PsStatement,
)
from ..ast.expressions import PsExpression, PsSymbolExpr
......@@ -73,7 +74,7 @@ class CanonicalClone:
),
)
case PsComment():
case PsComment() | PsPragma():
return cast(Node_T, node.clone())
case PsDeclaration(lhs, rhs):
......
......@@ -15,19 +15,39 @@ from .types import PsIntegerType, PsNumericType, PsIeeeFloatType
from .defaults import DEFAULTS
@dataclass
class OpenMpConfig:
    """Options that steer how a kernel is parallelized with OpenMP."""

    nesting_depth: int = 0
    """Nesting depth of the loop to be parallelized; has to be nonnegative."""

    collapse: int = 0
    """Value passed to the OpenMP ``collapse`` clause"""

    schedule: str = "static"
    """Value passed to the OpenMP ``schedule`` clause"""

    omit_parallel_construct: bool = False
    """Leave out the OpenMP ``parallel`` construct when ``True``, so that only ``#pragma omp for`` is produced.

    Only use this if the kernel is going to be wrapped into an external ``#pragma omp parallel`` region.
    """
@dataclass
class CpuOptimConfig:
"""Configuration for the CPU optimizer.
If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised.
"""
openmp: bool = False
openmp: bool | OpenMpConfig = False
"""Enable OpenMP parallelization.
If set to `True`, the kernel will be parallelized using OpenMP according to the OpenMP settings
given in this configuration.
If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpConfig`.
To customize OpenMP parallelization, pass an instance of `OpenMpConfig` instead.
"""
vectorize: bool | VectorizationConfig = False
......@@ -58,7 +78,7 @@ class CpuOptimConfig:
@dataclass
class VectorizationConfig:
"""Configuration for the auto-vectorizer.
If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised.
"""
......@@ -182,19 +202,27 @@ class CreateKernelConfig:
raise PsOptionsError(
"Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
)
# Check optim
if self.cpu_optim is not None:
if not self.target.is_cpu():
raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}")
if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu():
raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}")
raise PsOptionsError(
f"`cpu_optim` cannot be set for non-CPU target {self.target}"
)
if (
self.cpu_optim.vectorize is not False
and not self.target.is_vector_cpu()
):
raise PsOptionsError(
f"Cannot enable auto-vectorization for non-vector CPU target {self.target}"
)
# Infer JIT
if self.jit is None:
if self.target.is_cpu():
from .backend.jit import LegacyCpuJit
self.jit = LegacyCpuJit()
else:
raise NotImplementedError(
......
......@@ -104,6 +104,7 @@ def create_kernel(
# Target-Specific optimizations
if config.target.is_cpu():
from .backend.kernelcreation import optimize_cpu
kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
erase_anons = EraseAnonymousStructTypes(ctx)
......
import pytest
from pystencils import (
fields,
Assignment,
create_kernel,
CreateKernelConfig,
CpuOptimConfig,
OpenMpConfig,
Target,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsLoop, PsPragma
@pytest.mark.parametrize("nesting_depth", range(3))
@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"])
@pytest.mark.parametrize("collapse", range(3))
@pytest.mark.parametrize("omit_parallel_construct", range(3))
def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct):
f, g = fields("f, g: [3D]")
asm = Assignment(f.center(0), g.center(0))
omp = OpenMpConfig(
nesting_depth=nesting_depth,
schedule=schedule,
collapse=collapse,
omit_parallel_construct=omit_parallel_construct,
)
gen_config = CreateKernelConfig(
target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp)
)
kernel = create_kernel(asm, gen_config)
ast = kernel.body
def find_omp_pragma(ast) -> PsPragma:
num_loops = 0
generator = dfs_preorder(ast)
for node in generator:
match node:
case PsLoop():
num_loops += 1
case PsPragma():
loop = next(generator)
assert isinstance(loop, PsLoop)
assert num_loops == nesting_depth
return node
pytest.fail("No OpenMP pragma found")
pragma = find_omp_pragma(ast)
tokens = set(pragma.text.split())
expected_tokens = {"omp", "for", f"schedule({omp.schedule})"}
if not omp.omit_parallel_construct:
expected_tokens.add("parallel")
if omp.collapse > 0:
expected_tokens.add(f"collapse({omp.collapse})")
assert tokens == expected_tokens
......@@ -12,6 +12,7 @@ from pystencils.backend.ast.structural import (
PsBlock,
PsConditional,
PsComment,
PsPragma,
PsLoop,
)
from pystencils.types.quick import Fp, Ptr
......@@ -44,6 +45,7 @@ def test_cloning():
PsConditional(
y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
),
PsPragma("omp parallel for"),
PsLoop(
x,
y,
......@@ -54,6 +56,7 @@ def test_cloning():
PsComment("Loop body"),
PsAssignment(x, y),
PsAssignment(x, y),
PsPragma("#pragma clang loop vectorize(enable)"),
PsStatement(
PsDeref(PsCast(Ptr(Fp(32)), z))
+ PsSubscript(z, one + one + one)
......
import sympy as sp
from itertools import product
from pystencils import make_slice, fields, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop
from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma
def test_insert_pragmas():
    """Verify that `InsertPragmasAtLoops` puts pragmas at depth 0, depth 1, and the innermost loop."""
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)
    f, g = fields("f, g: [3D]")

    ispace = FullIterationSpace.create_from_slice(
        ctx, make_slice[:, :, :], archetype_field=f
    )
    ctx.set_iteration_space(ispace)

    # All offsets of the {-1, 0, 1}^3 neighborhood
    offsets = list(product([-1, 0, 1], repeat=3))
    update = Assignment(f.center(0), sum(g.neighbors(offsets)))
    loop_body = PsBlock([factory.parse_sympy(update)])
    loops = factory.loops_from_ispace(ispace, loop_body)

    outer_pragma = LoopPragma("omp parallel for", 0)
    middle_pragma = LoopPragma("some nonsense pragma", 1)
    innermost_pragma = LoopPragma("omp simd", -1)

    add_pragmas = InsertPragmasAtLoops(
        ctx, (outer_pragma, middle_pragma, innermost_pragma)
    )
    ast = add_pragmas(loops)

    # The outermost loop got wrapped into a block, preceded by the depth-0 pragma
    assert isinstance(ast, PsBlock)
    first_pragma = ast.statements[0]
    assert isinstance(first_pragma, PsPragma)
    assert first_pragma.text == outer_pragma.text
    assert ast.statements[1] == loops

    # The depth-1 pragma opens the body of the outermost loop
    second_pragma = loops.body.statements[0]
    assert isinstance(second_pragma, PsPragma)
    assert second_pragma.text == middle_pragma.text

    # The pragma requested at -1 opens the body of the second loop in pre-order
    all_loops = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))
    second_loop = all_loops[1]
    assert isinstance(second_loop, PsLoop)
    third_pragma = second_loop.body.statements[0]
    assert isinstance(third_pragma, PsPragma)
    assert third_pragma.text == innermost_pragma.text
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment