From 69e3439d6e34691b4b8e497f4fb4526300e00713 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Thu, 28 Nov 2019 17:55:12 +0100
Subject: [PATCH] Allow multiple staggered fields in one
 create_staggered_kernel_2 call

---
 pystencils/kernelcreation.py | 39 ++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 2ea3d4573..e4ade2f1b 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -6,7 +6,7 @@ import sympy as sp
 from pystencils.assignment import Assignment
 from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.vectorization import vectorize
-from pystencils.field import Field
+from pystencils.field import Field, FieldType
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.stencil import direction_string_to_offset, inverse_direction_string
@@ -314,9 +314,11 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
     Further index coordinates can be used to store vectors/tensors at each point.
 
     Args:
-        assignments: a sequence of assignments or an AssignmentCollection with one item for each staggered grid point.
-                     When storing vectors/tensors, the number of items expected is multiplied with the number of
-                     components.
+        assignments: a sequence of assignments or an AssignmentCollection.
+                     Assignments to staggered fields are processed specially, while subexpressions and assignments to
+                     regular fields are passed through to `create_kernel`. Multiple different staggered fields can be
+                     used, but they all need to use the same stencil (i.e. the same number of staggered points) and
+                     shape.
         gpu_exclusive_conditions: whether to use nested conditionals instead of multiple conditionals
         kwargs: passed directly to create_kernel, iteration_slice and ghost_layers parameters are not allowed
 
@@ -327,15 +329,26 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
 
     subexpressions = ()
     if isinstance(assignments, AssignmentCollection):
-        assignments = assignments.main_assignments
-        subexpressions = assignments.subexpressions
-    if len(set([a.lhs.field for a in assignments])) != 1:
-        raise ValueError("All assignments need to be made to the same staggered field")
+        # Everything that is not a write to a staggered field is passed through; `or` because symbols lack `.field`.
+        subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
+                                                       if type(a.lhs) is not Field.Access or
+                                                       not FieldType.is_staggered(a.lhs.field)]
+        assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access and
+                       FieldType.is_staggered(a.lhs.field)]
+    else:  # plain sequence of assignments: split it by the same predicate
+        subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access or
+                          not FieldType.is_staggered(a.lhs.field)]
+        assignments = [a for a in assignments if type(a.lhs) is Field.Access and FieldType.is_staggered(a.lhs.field)]
+    if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
+        raise ValueError("All assignments need to be made to staggered fields with the same stencil")
+    if len(set([a.lhs.field.shape for a in assignments])) != 1:
+        raise ValueError("All assignments need to be made to staggered fields with the same shape")
+
     staggered_field = assignments[0].lhs.field
+    stencil = staggered_field.staggered_stencil
     dim = staggered_field.spatial_dimensions
     points = staggered_field.index_shape[0]
-    values_per_point = sp.Mul(*staggered_field.index_shape[1:])
-    assert len(assignments) == points * values_per_point
+    shape = staggered_field.shape
 
     counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
 
@@ -343,7 +356,7 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
 
     # find out whether any of the ghost layers is not needed
     common_exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
-    for direction in staggered_field.staggered_stencil:
+    for direction in stencil:
         exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
         for elementary_direction in direction:
             exclusions.remove(inverse_direction_string(elementary_direction))
@@ -370,7 +383,7 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
             offset = direction_string_to_offset(e)
             for i, o in enumerate(offset):
                 if o == 1:
-                    conditions.append(counters[i] < staggered_field.shape[i] - 1)
+                    conditions.append(counters[i] < shape[i] - 1)
                 elif o == -1:
                     conditions.append(counters[i] > 0)
         return sp.And(*conditions)
@@ -378,7 +391,7 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
     if gpu_exclusive_conditions:
         raise NotImplementedError('gpu_exclusive_conditions is not implemented yet')
 
-    for d, direction in zip(range(points), staggered_field.staggered_stencil):
+    for d, direction in zip(range(points), stencil):
         sp_assignments = [SympyAssignment(assignments[d].lhs, assignments[d].rhs)] + \
                          [SympyAssignment(s.lhs, s.rhs) for s in subexpressions]
         last_conditional = Conditional(condition(direction), Block(sp_assignments))
-- 
GitLab