Skip to content
Snippets Groups Projects
Commit f6a9b096 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes for boundaries in kernels

parent 4179a44c
No related merge requests found
Pipeline #19052 failed with stages
in 4 minutes and 44 seconds
...@@ -2,14 +2,13 @@ import sympy as sp ...@@ -2,14 +2,13 @@ import sympy as sp
from lbmpy.boundaries.boundaryhandling import BoundaryOffsetInfo, LbmWeightInfo from lbmpy.boundaries.boundaryhandling import BoundaryOffsetInfo, LbmWeightInfo
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import LoopOverCoordinate from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.data_types import cast_func from pystencils.data_types import type_all_numbers
from pystencils.field import Field from pystencils.field import Field
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.simp.simplifications import sympy_cse_on_assignment_list from pystencils.simp.simplifications import sympy_cse_on_assignment_list
from pystencils.stencil import inverse_direction from pystencils.stencil import inverse_direction
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
from pystencils.astnodes import Block, Conditional
def direction_indices_in_direction(direction, stencil): def direction_indices_in_direction(direction, stencil):
...@@ -34,11 +33,6 @@ def boundary_substitutions(lb_method): ...@@ -34,11 +33,6 @@ def boundary_substitutions(lb_method):
return replacements return replacements
def type_all_numbers(expr, dtype):
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
return expr.subs(substitutions)
def border_conditions(direction, field, ghost_layers=1): def border_conditions(direction, field, ghost_layers=1):
abs_direction = tuple(-e if e < 0 else e for e in direction) abs_direction = tuple(-e if e < 0 else e for e in direction)
assert sum(abs_direction) == 1 assert sum(abs_direction) == 1
...@@ -87,59 +81,7 @@ def transformed_boundary_rule(boundary, accessor_func, field, direction_symbol, ...@@ -87,59 +81,7 @@ def transformed_boundary_rule(boundary, accessor_func, field, direction_symbol,
return ac.main_assignments[0].rhs return ac.main_assignments[0].rhs
def read_assignments_with_boundaries(method, pdf_field, boundary_spec, pre_stream_access, read_access): def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False):
stencil = method.stencil
reads = [Assignment(*v) for v in zip(method.pre_collision_pdf_symbols,
read_access(pdf_field, method.stencil))]
for direction, boundary in boundary_spec.items():
dir_indices = direction_indices_in_direction(direction, method.stencil)
border_cond = border_conditions(direction, pdf_field, ghost_layers=1)
for dir_index in dir_indices:
inv_index = stencil.index(inverse_direction(stencil[dir_index]))
value_from_boundary = transformed_boundary_rule(boundary, pre_stream_access, pdf_field, dir_index,
method, index_field=None)
value_without_boundary = reads[inv_index].rhs
new_rhs = sp.Piecewise((value_from_boundary, border_cond),
(value_without_boundary, True))
reads[inv_index] = Assignment(reads[inv_index].lhs, new_rhs)
return AssignmentCollection(reads)
def update_rule_with_boundaries(collision_rule, input_field, output_field,
boundaries, accessor, pre_stream_access):
reads = read_assignments_with_boundaries(collision_rule.method, input_field, boundaries,
pre_stream_access, accessor.read)
write_substitutions = {}
method = collision_rule.method
post_collision_symbols = method.post_collision_pdf_symbols
pre_collision_symbols = method.pre_collision_pdf_symbols
output_accesses = accessor.write(output_field, method.stencil)
input_accesses = accessor.read(input_field, method.stencil)
for (idx, offset), output_access in zip(enumerate(method.stencil), output_accesses):
write_substitutions[post_collision_symbols[idx]] = output_access
result = collision_rule.new_with_substitutions(write_substitutions)
result.subexpressions = reads.all_assignments + result.subexpressions
if 'split_groups' in result.simplification_hints:
all_substitutions = write_substitutions.copy()
for (idx, offset), input_access in zip(enumerate(method.stencil), input_accesses):
all_substitutions[pre_collision_symbols[idx]] = input_access
new_split_groups = []
for split_group in result.simplification_hints['split_groups']:
new_split_groups.append([fast_subs(e, all_substitutions) for e in split_group])
result.simplification_hints['split_groups'] = new_split_groups
return result
def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False, **kwargs):
stencil = lb_method.stencil stencil = lb_method.stencil
tmp_field = output_field.new_field_with_different_name("t") tmp_field = output_field.new_field_with_different_name("t")
...@@ -147,7 +89,7 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, ...@@ -147,7 +89,7 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method,
assignments = [] assignments = []
for direction_idx in dir_indices: for direction_idx in dir_indices:
rule = boundary(tmp_field, direction_idx, lb_method, **kwargs) rule = boundary(tmp_field, direction_idx, lb_method, index_field=None)
boundary_subs = boundary_substitutions(lb_method) boundary_subs = boundary_substitutions(lb_method)
rule = [a.subs(boundary_subs) for a in rule] rule = [a.subs(boundary_subs) for a in rule]
...@@ -165,19 +107,28 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, ...@@ -165,19 +107,28 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method,
border_cond = border_conditions(direction, output_field, ghost_layers=1) border_cond = border_conditions(direction, output_field, ghost_layers=1)
if cse: if cse:
assignments = sympy_cse_on_assignment_list(assignments) assignments = sympy_cse_on_assignment_list(assignments)
assignments = [SympyAssignment(a.lhs, a.rhs) for a in assignments]
return Conditional(border_cond, Block(assignments)) return Conditional(border_cond, Block(assignments))
def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, accessor, read_of_next_accessor): def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, accessor, read_of_next_accessor):
if 'split_groups' in collision_rule.simplification_hints:
raise NotImplementedError("Split is not supported yet")
method = collision_rule.method method = collision_rule.method
loads = [Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols, accessor.read(field, method.stencil))] loads = [Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols, accessor.read(field, method.stencil))]
stores = [Assignment(a, b) for a, b in stores = [Assignment(a, b) for a, b in
zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)] zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)]
result = loads + collision_rule.all_assignments + stores result = collision_rule.copy()
result.subexpressions = loads + result.subexpressions
result.main_assignments += stores
for direction, boundary in boundary_spec.items(): for direction, boundary in boundary_spec.items():
cond = boundary_conditional(boundary, direction, read_of_next_accessor, method, field) cond = boundary_conditional(boundary, direction, read_of_next_accessor, method, field)
result.append(cond) result.main_assignments.append(cond)
return result
\ No newline at end of file if 'split_groups' in result.simplification_hints:
substitutions = {b: a for a, b in zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)}
new_split_groups = []
for split_group in result.simplification_hints['split_groups']:
new_split_groups.append([fast_subs(e, substitutions) for e in split_group])
result.simplification_hints['split_groups'] = new_split_groups
return result
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment