inkernel.py 1.98 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
import sympy as sp
from pystencils import Field, TypedSymbol
Martin Bauer's avatar
Martin Bauer committed
3
from pystencils.integer_functions import bitwise_and
Martin Bauer's avatar
Martin Bauer committed
4
from pystencils.boundaries.boundaryhandling import FlagInterface
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.data_types import create_type
Martin Bauer's avatar
Martin Bauer committed
6
7


Martin Bauer's avatar
Martin Bauer committed
8
def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False):
    """
    Replaces all neighbor accesses by flag field guarded accesses.

    Every access to a direct neighbor of one of the given fields is replaced by a ``Piecewise``
    expression: if the flag bit(s) ``boundary_flag`` are set in ``flag_field`` at the neighbor's
    position, the center value of the field is used instead of the neighbor value.

    :param eqs: iterable of equations containing field accesses to direct neighbors
                (all offsets in {-1, 0, 1}); it is materialized once, so generators are accepted
    :param fields: a single field or a collection of fields for which the Neumann boundary
                   should be applied
    :param flag_field: integer field marking boundary cells
    :param boundary_flag: bit mask tested against the flag field. If a string is given, it is
                          turned into a TypedSymbol of the flag data type
                          (``FlagInterface.FLAG_DTYPE``). A neighbor cell is treated as boundary
                          if ``flag & boundary_flag != 0``
    :param inverse_flag: if true, the test is inverted: a neighbor cell is treated as boundary
                         if ``flag & boundary_flag == 0``
    :return: list of equations with guarded field accesses
    :raises ValueError: if an access to one of the given fields has an offset outside {-1, 0, 1}
    """
    # The equations are traversed twice (collect accesses, then substitute). Materialize them
    # once so that a one-shot iterator (e.g. a generator) does not silently yield an empty result.
    eqs = list(eqs)

    # Accept a single field as well as a collection of fields.
    if not hasattr(fields, "__len__"):
        fields = [fields]
    fields = set(fields)

    if isinstance(boundary_flag, str):
        boundary_flag = TypedSymbol(boundary_flag, dtype=create_type(FlagInterface.FLAG_DTYPE))

    substitutions = {}
    for eq in eqs:
        for fa in eq.atoms(Field.Access):
            if fa.field not in fields:
                continue
            if not all(offset in (-1, 0, 1) for offset in fa.offsets):
                raise ValueError("Works only for single neighborhood stencils")
            if all(offset == 0 for offset in fa.offsets):
                continue  # center accesses need no guard

            # Flag bits of the neighbor cell this access reads from.
            flag_bits = bitwise_and(flag_field[tuple(fa.offsets)], boundary_flag)
            condition = sp.Eq(flag_bits, 0) if inverse_flag else sp.Ne(flag_bits, 0)

            # Same field and index components, but at the center cell (zero offsets).
            center = fa.field(*fa.index)
            substitutions[fa] = sp.Piecewise((center, condition), (fa, True))
    return [eq.subs(substitutions) for eq in eqs]