From 23a8f33d6c1b8ba04d8b295a7eaef4c8164e5ece Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Sat, 30 Nov 2019 23:47:05 +0100
Subject: [PATCH] create_staggered_kernel: re-add gpu_exclusive_conditions

---
 pystencils/kernelcreation.py | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index b968d4c62..16876bd0e 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -191,7 +191,7 @@ def create_indexed_kernel(assignments,
         raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
 
 
-def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwargs):
+def create_staggered_kernel(assignments, target='cpu', gpu_exclusive_conditions=False, **kwargs):
     """Kernel that updates a staggered field.
 
     .. image:: /img/staggered_grid.svg
@@ -205,7 +205,9 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
                      regular fields are passed through to `create_kernel`. Multiple different staggered fields can be
                      used, but they all need to use the same stencil (i.e. the same number of staggered points) and
                      shape.
-        gpu_exclusive_conditions: whether to use nested conditionals instead of multiple conditionals
+        target: 'cpu', 'llvm' or 'gpu'
+        gpu_exclusive_conditions: disable the use of multiple conditionals inside the loop. The outer layers are then
+                                  handled in an else branch.
         kwargs: passed directly to create_kernel, iteration_slice and ghost_layers parameters are not allowed
 
     Returns:
@@ -277,7 +279,28 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
         return sp.And(*conditions)
 
     if gpu_exclusive_conditions:
-        raise NotImplementedError('gpu_exclusive_conditions is not implemented yet')
+        outer_assignment = None
+        for assignment in assignments:
+            direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]]
+            assignment = SympyAssignment(assignment.lhs, assignment.rhs)
+            outer_assignment = Conditional(condition(direction), Block([assignment]), outer_assignment)
+
+        inner_assignment = []
+        for assignment in assignments:
+            direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]]
+            inner_assignment.append(SympyAssignment(assignment.lhs, assignment.rhs))
+        last_conditional = Conditional(sp.And(*[condition(d) for d in stencil]),
+                                       Block(inner_assignment), outer_assignment)
+        final_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
+                            [SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
+                            [last_conditional]
+
+        if target == 'cpu':
+            from pystencils.cpu import create_kernel as create_kernel_cpu
+            ast = create_kernel_cpu(final_assignments, ghost_layers=ghost_layers, **kwargs)
+        else:
+            ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
+        return ast
 
     for assignment in assignments:
         direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]]
@@ -290,6 +313,6 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
     remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers])
     prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional),
                              move_constants_before_loop]
-    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, cpu_prepend_optimizations=prepend_optimizations,
-                        **kwargs)
+    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target,
+                        cpu_prepend_optimizations=prepend_optimizations, **kwargs)
     return ast
-- 
GitLab