Commit 358c44b8 authored by Michael Kuron's avatar Michael Kuron
Browse files

create_staggered_kernel bugfixes to store pygrandchem compatibility

also add a test for subexpressions
parent 3b18545b
......@@ -213,16 +213,15 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
"""
assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
subexpressions = ()
if isinstance(assignments, AssignmentCollection):
subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
if type(a.lhs) is not Field.Access
and not FieldType.is_staggered(a.lhs.field)]
or not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access
and FieldType.is_staggered(a.lhs.field)]
else:
subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access
and not FieldType.is_staggered(a.lhs.field)]
or not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments if type(a.lhs) is Field.Access
and FieldType.is_staggered(a.lhs.field)]
if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
......@@ -233,7 +232,6 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
staggered_field = assignments[0].lhs.field
stencil = staggered_field.staggered_stencil
dim = staggered_field.spatial_dimensions
points = staggered_field.index_shape[0]
shape = staggered_field.shape
counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
......@@ -277,9 +275,10 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
if gpu_exclusive_conditions:
raise NotImplementedError('gpu_exclusive_conditions is not implemented yet')
for d, direction in zip(range(points), stencil):
sp_assignments = [SympyAssignment(assignments[d].lhs, assignments[d].rhs)] + \
[SympyAssignment(s.lhs, s.rhs) for s in subexpressions]
for assignment in assignments:
direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]]
sp_assignments = [SympyAssignment(s.lhs, s.rhs) for s in subexpressions] + \
[SympyAssignment(assignment.lhs, assignment.rhs)]
last_conditional = Conditional(condition(direction), Block(sp_assignments))
final_assignments.append(last_conditional)
......
......@@ -3,7 +3,7 @@ import numpy as np
import sympy as sp
class TestDiffusion:
class TestStaggeredDiffusion:
def _run(self, num_neighbors):
L = (40, 40)
D = 0.066
......@@ -65,3 +65,12 @@ class TestDiffusion:
def test_diffusion_4(self):
self._run(4)
def test_staggered_subexpressions():
dh = ps.create_data_handling((10, 10), periodicity=True, default_target='cpu')
j = dh.add_array('j', values_per_cell=2, field_type=ps.FieldType.STAGGERED)
c = sp.symbols("c")
assignments = [ps.Assignment(j.staggered_access("W"), c),
ps.Assignment(c, 1)]
ps.create_staggered_kernel(assignments, target=dh.default_target).compile()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment