Commit 69e3439d authored by Michael Kuron's avatar Michael Kuron
Browse files

Allow multiple staggered fields in one create_staggered_kernel_2 call

parent c819048d
......@@ -6,7 +6,7 @@ import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.vectorization import vectorize
from pystencils.field import Field
from pystencils.field import Field, FieldType
from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.stencil import direction_string_to_offset, inverse_direction_string
......@@ -314,9 +314,11 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
Further index coordinates can be used to store vectors/tensors at each point.
Args:
assignments: a sequence of assignments or an AssignmentCollection with one item for each staggered grid point.
When storing vectors/tensors, the number of items expected is multiplied with the number of
components.
assignments: a sequence of assignments or an AssignmentCollection.
Assignments to staggered field are processed specially, while subexpressions and assignments to
regular fields are passed through to `create_kernel`. Multiple different staggered fields can be
used, but they all need to use the same stencil (i.e. the same number of staggered points) and
shape.
gpu_exclusive_conditions: whether to use nested conditionals instead of multiple conditionals
kwargs: passed directly to create_kernel, iteration_slice and ghost_layers parameters are not allowed
......@@ -327,15 +329,26 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
subexpressions = ()
if isinstance(assignments, AssignmentCollection):
assignments = assignments.main_assignments
subexpressions = assignments.subexpressions
if len(set([a.lhs.field for a in assignments])) != 1:
raise ValueError("All assignments need to be made to the same staggered field")
subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
if type(a.lhs) is not Field.Access and
not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access and
FieldType.is_staggered(a.lhs.field)]
else:
subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access and
not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments if type(a.lhs) is Field.Access and
FieldType.is_staggered(a.lhs.field)]
if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
raise ValueError("All assignments need to be made to staggered fields with the same stencil")
if len(set([a.lhs.field.shape for a in assignments])) != 1:
raise ValueError("All assignments need to be made to staggered fields with the same shape")
staggered_field = assignments[0].lhs.field
stencil = staggered_field.staggered_stencil
dim = staggered_field.spatial_dimensions
points = staggered_field.index_shape[0]
values_per_point = sp.Mul(*staggered_field.index_shape[1:])
assert len(assignments) == points * values_per_point
shape = staggered_field.shape
counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
......@@ -343,7 +356,7 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
# find out whether any of the ghost layers is not needed
common_exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for direction in staggered_field.staggered_stencil:
for direction in stencil:
exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for elementary_direction in direction:
exclusions.remove(inverse_direction_string(elementary_direction))
......@@ -370,7 +383,7 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
offset = direction_string_to_offset(e)
for i, o in enumerate(offset):
if o == 1:
conditions.append(counters[i] < staggered_field.shape[i] - 1)
conditions.append(counters[i] < shape[i] - 1)
elif o == -1:
conditions.append(counters[i] > 0)
return sp.And(*conditions)
......@@ -378,7 +391,7 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
if gpu_exclusive_conditions:
raise NotImplementedError('gpu_exclusive_conditions is not implemented yet')
for d, direction in zip(range(points), staggered_field.staggered_stencil):
for d, direction in zip(range(points), stencil):
sp_assignments = [SympyAssignment(assignments[d].lhs, assignments[d].rhs)] + \
[SympyAssignment(s.lhs, s.rhs) for s in subexpressions]
last_conditional = Conditional(condition(direction), Block(sp_assignments))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment