From d7332d595078d4be8832634280788a8ce5475fba Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 5 Apr 2019 07:28:23 +0200
Subject: [PATCH] OpenMP support for staggered kernels

---
 pystencils/cpu/kernelcreation.py | 63 +++++++++++++++++---------------
 pystencils/kernelcreation.py     | 10 ++++-
 2 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index a3f0211..6b94b6d 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -3,7 +3,7 @@ from functools import partial
 from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
 from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
     add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
-    split_inner_loop, get_base_buffer_index
+    split_inner_loop, get_base_buffer_index, filtered_tree_iteration
 from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
 from pystencils.field import Field, FieldType
 import pystencils.astnodes as ast
@@ -152,7 +152,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
     return ast_node
 
 
-def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None):
+def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, assume_single_outer_loop=True):
     """Parallelize the outer loop with OpenMP.
 
     Args:
@@ -160,6 +160,8 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None):
         schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic'
         num_threads: explicitly specify number of threads
         collapse: number of nested loops to include in parallel region (see OpenMP collapse)
+        assume_single_outer_loop: if True an exception is raised if multiple outer loops are detected for all but
+                                  optimized staggered kernels the single-outer-loop assumption should be true
     """
     if not num_threads:
         return
@@ -170,31 +172,34 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None):
     wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
     body.append(wrapper_block)
 
-    outer_loops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.is_outermost_loop]
+    outer_loops = [l for l in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment)
+                   if l.is_outermost_loop]
     assert outer_loops, "No outer loop found"
-    assert len(outer_loops) <= 1, "More than one outer loop found. Not clear where to put OpenMP pragma."
-    loop_to_parallelize = outer_loops[0]
-    try:
-        loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
-    except TypeError:
-        loop_range = None
-
-    if num_threads is None:
-        import multiprocessing
-        num_threads = multiprocessing.cpu_count()
-
-    if loop_range is not None and loop_range < num_threads and not collapse:
-        contained_loops = [l for l in loop_to_parallelize.body.args if isinstance(l, LoopOverCoordinate)]
-        if len(contained_loops) == 1:
-            contained_loop = contained_loops[0]
-            try:
-                contained_loop_range = int(contained_loop.stop - contained_loop.start)
-                if contained_loop_range > loop_range:
-                    loop_to_parallelize = contained_loop
-            except TypeError:
-                pass
-
-    prefix = "#pragma omp for schedule(%s)" % (schedule,)
-    if collapse:
-        prefix += " collapse(%d)" % (collapse, )
-    loop_to_parallelize.prefix_lines.append(prefix)
+    if assume_single_outer_loop and len(outer_loops) > 1:
+        raise ValueError("More than one outer loop found, only one outer loop expected")
+
+    for loop_to_parallelize in outer_loops:
+        try:
+            loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
+        except TypeError:
+            loop_range = None
+
+        if num_threads is None:
+            import multiprocessing
+            num_threads = multiprocessing.cpu_count()
+
+        if loop_range is not None and loop_range < num_threads and not collapse:
+            contained_loops = [l for l in loop_to_parallelize.body.args if isinstance(l, LoopOverCoordinate)]
+            if len(contained_loops) == 1:
+                contained_loop = contained_loops[0]
+                try:
+                    contained_loop_range = int(contained_loop.stop - contained_loop.start)
+                    if contained_loop_range > loop_range:
+                        loop_to_parallelize = contained_loop
+                except TypeError:
+                    pass
+
+        prefix = "#pragma omp for schedule(%s)" % (schedule,)
+        if collapse:
+            prefix += " collapse(%d)" % (collapse, )
+        loop_to_parallelize.prefix_lines.append(prefix)
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 024ee4a..e8de7d2 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -245,13 +245,21 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
     cpu_vectorize_info = kwargs.get('cpu_vectorize_info', None)
     if cpu_vectorize_info:
         del kwargs['cpu_vectorize_info']
+    openmp = kwargs.get('cpu_openmp', None)
+    if openmp:
+        del kwargs['cpu_openmp']
+
     ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
 
     if target == 'cpu':
         remove_conditionals_in_staggered_kernel(ast)
         move_constants_before_loop(ast)
+        omp_collapse = None
         if blocking:
-            loop_blocking(ast, blocking)
+            omp_collapse = loop_blocking(ast, blocking)
+        if openmp:
+            from pystencils.cpu import add_openmp
+            add_openmp(ast, num_threads=openmp, collapse=omp_collapse, assume_single_outer_loop=False)
         if cpu_vectorize_info is True:
             vectorize(ast)
         elif isinstance(cpu_vectorize_info, dict):
-- 
GitLab