Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pytest
from pystencils import (
fields,
Assignment,
create_kernel,
CreateKernelConfig,
CpuOptimConfig,
OpenMpConfig,
Target,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsLoop, PsPragma
@pytest.mark.parametrize("nesting_depth", range(3))
@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"])
@pytest.mark.parametrize("collapse", range(3))
@pytest.mark.parametrize("omit_parallel_construct", range(3))
def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct):
f, g = fields("f, g: [3D]")
asm = Assignment(f.center(0), g.center(0))
omp = OpenMpConfig(
nesting_depth=nesting_depth,
schedule=schedule,
collapse=collapse,
omit_parallel_construct=omit_parallel_construct,
)
gen_config = CreateKernelConfig(
target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp)
)
kernel = create_kernel(asm, gen_config)
ast = kernel.body
def find_omp_pragma(ast) -> PsPragma:
num_loops = 0
generator = dfs_preorder(ast)
for node in generator:
match node:
case PsLoop():
num_loops += 1
case PsPragma():
loop = next(generator)
assert isinstance(loop, PsLoop)
assert num_loops == nesting_depth
return node
pytest.fail("No OpenMP pragma found")
pragma = find_omp_pragma(ast)
tokens = set(pragma.text.split())
expected_tokens = {"omp", "for", f"schedule({omp.schedule})"}
if not omp.omit_parallel_construct:
expected_tokens.add("parallel")
if omp.collapse > 0:
expected_tokens.add(f"collapse({omp.collapse})")
assert tokens == expected_tokens