From 358c44b811517e2ba9d79510a30cb22497a52246 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Fri, 29 Nov 2019 11:34:26 +0100
Subject: [PATCH] create_staggered_kernel bugfixes to store pygrandchem
 compatibility

also add a test for subexpressions
---
 pystencils/kernelcreation.py                        | 13 ++++++-------
 ...ggered_diffusion.py => test_staggered_kernel.py} | 11 ++++++++++-
 2 files changed, 16 insertions(+), 8 deletions(-)
 rename pystencils_tests/{test_staggered_diffusion.py => test_staggered_kernel.py} (84%)

diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index b5cfc5a4e..8feadc001 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -213,16 +213,15 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
     """
     assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
 
-    subexpressions = ()
     if isinstance(assignments, AssignmentCollection):
         subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
                                                        if type(a.lhs) is not Field.Access
-                                                       and not FieldType.is_staggered(a.lhs.field)]
+                                                       or not FieldType.is_staggered(a.lhs.field)]
         assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access
                        and FieldType.is_staggered(a.lhs.field)]
     else:
         subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access
-                          and not FieldType.is_staggered(a.lhs.field)]
+                          or not FieldType.is_staggered(a.lhs.field)]
         assignments = [a for a in assignments if type(a.lhs) is Field.Access 
                        and FieldType.is_staggered(a.lhs.field)]
     if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
@@ -233,7 +232,6 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
     staggered_field = assignments[0].lhs.field
     stencil = staggered_field.staggered_stencil
     dim = staggered_field.spatial_dimensions
-    points = staggered_field.index_shape[0]
     shape = staggered_field.shape
 
     counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
@@ -277,9 +275,10 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
     if gpu_exclusive_conditions:
         raise NotImplementedError('gpu_exclusive_conditions is not implemented yet')
 
-    for d, direction in zip(range(points), stencil):
-        sp_assignments = [SympyAssignment(assignments[d].lhs, assignments[d].rhs)] + \
-                         [SympyAssignment(s.lhs, s.rhs) for s in subexpressions]
+    for assignment in assignments:
+        direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]]
+        sp_assignments = [SympyAssignment(s.lhs, s.rhs) for s in subexpressions] + \
+                         [SympyAssignment(assignment.lhs, assignment.rhs)]
         last_conditional = Conditional(condition(direction), Block(sp_assignments))
         final_assignments.append(last_conditional)
 
diff --git a/pystencils_tests/test_staggered_diffusion.py b/pystencils_tests/test_staggered_kernel.py
similarity index 84%
rename from pystencils_tests/test_staggered_diffusion.py
rename to pystencils_tests/test_staggered_kernel.py
index 4d0fb294e..0937dc98e 100644
--- a/pystencils_tests/test_staggered_diffusion.py
+++ b/pystencils_tests/test_staggered_kernel.py
@@ -3,7 +3,7 @@ import numpy as np
 import sympy as sp
 
 
-class TestDiffusion:
+class TestStaggeredDiffusion:
     def _run(self, num_neighbors):
         L = (40, 40)
         D = 0.066
@@ -65,3 +65,12 @@ class TestDiffusion:
 
     def test_diffusion_4(self):
         self._run(4)
+
+
+def test_staggered_subexpressions():
+    dh = ps.create_data_handling((10, 10), periodicity=True, default_target='cpu')
+    j = dh.add_array('j', values_per_cell=2, field_type=ps.FieldType.STAGGERED)
+    c = sp.symbols("c")
+    assignments = [ps.Assignment(j.staggered_access("W"), c),
+                   ps.Assignment(c, 1)]
+    ps.create_staggered_kernel(assignments, target=dh.default_target).compile()
-- 
GitLab