diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index b5cfc5a4eda444f615474503ecbb79800c8e3969..8feadc001416be597e2c227aeb64f701813af293 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -213,16 +213,15 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg """ assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs - subexpressions = () if isinstance(assignments, AssignmentCollection): subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments if type(a.lhs) is not Field.Access - and not FieldType.is_staggered(a.lhs.field)] + or not FieldType.is_staggered(a.lhs.field)] assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access and FieldType.is_staggered(a.lhs.field)] else: subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access - and not FieldType.is_staggered(a.lhs.field)] + or not FieldType.is_staggered(a.lhs.field)] assignments = [a for a in assignments if type(a.lhs) is Field.Access and FieldType.is_staggered(a.lhs.field)] if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1: @@ -233,7 +232,6 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg staggered_field = assignments[0].lhs.field stencil = staggered_field.staggered_stencil dim = staggered_field.spatial_dimensions - points = staggered_field.index_shape[0] shape = staggered_field.shape counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)] @@ -277,9 +275,10 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg if gpu_exclusive_conditions: raise NotImplementedError('gpu_exclusive_conditions is not implemented yet') - for d, direction in zip(range(points), stencil): - sp_assignments = [SympyAssignment(assignments[d].lhs, assignments[d].rhs)] + \ - [SympyAssignment(s.lhs, s.rhs) for s in subexpressions] + for assignment in assignments: + direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]] + sp_assignments = [SympyAssignment(s.lhs, s.rhs) for s in subexpressions] + \ + [SympyAssignment(assignment.lhs, assignment.rhs)] last_conditional = Conditional(condition(direction), Block(sp_assignments)) final_assignments.append(last_conditional) diff --git a/pystencils_tests/test_staggered_diffusion.py b/pystencils_tests/test_staggered_kernel.py similarity index 84% rename from pystencils_tests/test_staggered_diffusion.py rename to pystencils_tests/test_staggered_kernel.py index 4d0fb294e8536d1d827961b1feb043ee75049514..0937dc98e6ed6eaa2d35f974b12fe411a3bafe5d 100644 --- a/pystencils_tests/test_staggered_diffusion.py +++ b/pystencils_tests/test_staggered_kernel.py @@ -3,7 +3,7 @@ import numpy as np import sympy as sp -class TestDiffusion: +class TestStaggeredDiffusion: def _run(self, num_neighbors): L = (40, 40) D = 0.066 @@ -65,3 +65,12 @@ class TestDiffusion: def test_diffusion_4(self): self._run(4) + + +def test_staggered_subexpressions(): + dh = ps.create_data_handling((10, 10), periodicity=True, default_target='cpu') + j = dh.add_array('j', values_per_cell=2, field_type=ps.FieldType.STAGGERED) + c = sp.symbols("c") + assignments = [ps.Assignment(j.staggered_access("W"), c), + ps.Assignment(c, 1)] + ps.create_staggered_kernel(assignments, target=dh.default_target).compile()