Commit f6a9b096 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes for boundaries in kernels

parent 4179a44c
......@@ -2,14 +2,13 @@ import sympy as sp
from lbmpy.boundaries.boundaryhandling import BoundaryOffsetInfo, LbmWeightInfo
from pystencils.assignment import Assignment
from pystencils.astnodes import LoopOverCoordinate
from pystencils.data_types import cast_func
from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.data_types import type_all_numbers
from pystencils.field import Field
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.simp.simplifications import sympy_cse_on_assignment_list
from pystencils.stencil import inverse_direction
from pystencils.sympyextensions import fast_subs
from pystencils.astnodes import Block, Conditional
def direction_indices_in_direction(direction, stencil):
......@@ -34,11 +33,6 @@ def boundary_substitutions(lb_method):
return replacements
def type_all_numbers(expr, dtype):
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
return expr.subs(substitutions)
def border_conditions(direction, field, ghost_layers=1):
abs_direction = tuple(-e if e < 0 else e for e in direction)
assert sum(abs_direction) == 1
......@@ -87,59 +81,7 @@ def transformed_boundary_rule(boundary, accessor_func, field, direction_symbol,
return ac.main_assignments[0].rhs
def read_assignments_with_boundaries(method, pdf_field, boundary_spec, pre_stream_access, read_access):
stencil = method.stencil
reads = [Assignment(*v) for v in zip(method.pre_collision_pdf_symbols,
read_access(pdf_field, method.stencil))]
for direction, boundary in boundary_spec.items():
dir_indices = direction_indices_in_direction(direction, method.stencil)
border_cond = border_conditions(direction, pdf_field, ghost_layers=1)
for dir_index in dir_indices:
inv_index = stencil.index(inverse_direction(stencil[dir_index]))
value_from_boundary = transformed_boundary_rule(boundary, pre_stream_access, pdf_field, dir_index,
method, index_field=None)
value_without_boundary = reads[inv_index].rhs
new_rhs = sp.Piecewise((value_from_boundary, border_cond),
(value_without_boundary, True))
reads[inv_index] = Assignment(reads[inv_index].lhs, new_rhs)
return AssignmentCollection(reads)
def update_rule_with_boundaries(collision_rule, input_field, output_field,
boundaries, accessor, pre_stream_access):
reads = read_assignments_with_boundaries(collision_rule.method, input_field, boundaries,
pre_stream_access, accessor.read)
write_substitutions = {}
method = collision_rule.method
post_collision_symbols = method.post_collision_pdf_symbols
pre_collision_symbols = method.pre_collision_pdf_symbols
output_accesses = accessor.write(output_field, method.stencil)
input_accesses = accessor.read(input_field, method.stencil)
for (idx, offset), output_access in zip(enumerate(method.stencil), output_accesses):
write_substitutions[post_collision_symbols[idx]] = output_access
result = collision_rule.new_with_substitutions(write_substitutions)
result.subexpressions = reads.all_assignments + result.subexpressions
if 'split_groups' in result.simplification_hints:
all_substitutions = write_substitutions.copy()
for (idx, offset), input_access in zip(enumerate(method.stencil), input_accesses):
all_substitutions[pre_collision_symbols[idx]] = input_access
new_split_groups = []
for split_group in result.simplification_hints['split_groups']:
new_split_groups.append([fast_subs(e, all_substitutions) for e in split_group])
result.simplification_hints['split_groups'] = new_split_groups
return result
def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False, **kwargs):
def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False):
stencil = lb_method.stencil
tmp_field = output_field.new_field_with_different_name("t")
......@@ -147,7 +89,7 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method,
assignments = []
for direction_idx in dir_indices:
rule = boundary(tmp_field, direction_idx, lb_method, **kwargs)
rule = boundary(tmp_field, direction_idx, lb_method, index_field=None)
boundary_subs = boundary_substitutions(lb_method)
rule = [a.subs(boundary_subs) for a in rule]
......@@ -165,19 +107,28 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method,
border_cond = border_conditions(direction, output_field, ghost_layers=1)
if cse:
assignments = sympy_cse_on_assignment_list(assignments)
assignments = [SympyAssignment(a.lhs, a.rhs) for a in assignments]
return Conditional(border_cond, Block(assignments))
def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, accessor, read_of_next_accessor):
if 'split_groups' in collision_rule.simplification_hints:
raise NotImplementedError("Split is not supported yet")
method = collision_rule.method
loads = [Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols, accessor.read(field, method.stencil))]
stores = [Assignment(a, b) for a, b in
zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)]
result = loads + collision_rule.all_assignments + stores
result = collision_rule.copy()
result.subexpressions = loads + result.subexpressions
result.main_assignments += stores
for direction, boundary in boundary_spec.items():
cond = boundary_conditional(boundary, direction, read_of_next_accessor, method, field)
result.append(cond)
return result
\ No newline at end of file
result.main_assignments.append(cond)
if 'split_groups' in result.simplification_hints:
substitutions = {b: a for a, b in zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)}
new_split_groups = []
for split_group in result.simplification_hints['split_groups']:
new_split_groups.append([fast_subs(e, substitutions) for e in split_group])
result.simplification_hints['split_groups'] = new_split_groups
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment