assignment.py 4.29 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import numpy as np
2
import sympy as sp
3
from sympy.codegen.ast import Assignment
4
5
from sympy.printing.latex import LatexPrinter

6
# Public API of this module; ``Assignment`` is the sympy codegen class imported above.
__all__ = [
    'Assignment',
    'assignment_from_stencil',
]
7
8
9
10
11
12


def print_assignment_latex(printer, expr):
    """Render an Assignment as LaTeX, in the form ``lhs \\leftarrow rhs``.

    sympy cannot print Assignments as LaTeX, so this function is added to the
    sympy LaTeX printer as its ``_print_Assignment`` handler.

    Args:
        printer: printer whose ``doprint`` is used to render both sides
        expr: assignment-like object with ``lhs`` and ``rhs`` attributes

    Returns:
        LaTeX string ``"<lhs> \\leftarrow <rhs>"``
    """
    printed_lhs = printer.doprint(expr.lhs)
    printed_rhs = printer.doprint(expr.rhs)
    return rf"{printed_lhs} \leftarrow {printed_rhs}"
14
15


Martin Bauer's avatar
Martin Bauer committed
16
def assignment_str(assignment):
    """Return a human-readable ``"lhs ← rhs"`` string for an assignment.

    Args:
        assignment: object with ``lhs`` and ``rhs`` attributes

    Returns:
        ``str(lhs) + " ← " + str(rhs)``
    """
    return f"{assignment.lhs} ← {assignment.rhs}"
18
19


20
# Keep a reference to sympy's original constructor so the patched one can delegate to it.
# (``Assignment`` is imported from ``sympy.codegen.ast``, i.e. the same class object.)
_old_new = Assignment.__new__
21

22

23
24
25
26
27
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
    """Constructor for ``Assignment`` that additionally supports vector assignments.

    If both ``lhs`` and ``rhs`` are a list, tuple or ``sp.Matrix``, one assignment is
    created per element pair and a tuple of them is returned. Otherwise the call is
    forwarded unchanged to ``_old_new`` (sympy's original constructor).

    Raises:
        AssertionError: if ``lhs`` and ``rhs`` are sequences of different length
    """
    if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
        if len(lhs) != len(rhs):
            # Raised explicitly rather than via ``assert`` so the check survives ``python -O``
            # (otherwise zip() would silently drop surplus elements). The AssertionError type is
            # kept so existing callers that expect it keep working.
            raise AssertionError(f'{lhs} and {rhs} must have same length when performing vector assignment!')
        return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
    return _old_new(cls, lhs, rhs, *args, **kwargs)
28
29


30
31
32
# Monkey-patch sympy so Assignments get LaTeX output, the vector-aware constructor
# and the "lhs ← rhs" string form defined above.
LatexPrinter._print_Assignment = print_assignment_latex
Assignment.__new__ = _Assignment__new__
Assignment.__str__ = assignment_str
Martin Bauer's avatar
Martin Bauer committed
33

34
def _mutable_matrix_hash(matrix):
    """Hash a mutable sympy matrix via ``tuple(matrix)``, i.e. by its current elements.

    Note: because the hash follows the current contents, a matrix that is mutated after
    being inserted into a set/dict will no longer be found there.
    """
    return hash(tuple(matrix))


sp.MutableDenseMatrix.__hash__ = _mutable_matrix_hash
Martin Bauer's avatar
Martin Bauer committed
35

36

37
38
39
40
41
42
43
# In SymPy 1.4 (and older) Assignment.__hash__ is not implemented; per the original note this
# was fixed upstream afterwards. Patch it in only for those old versions.
# Best effort: if the version string cannot be parsed, sympy is left untouched.
try:
    sympy_version = sp.__version__.split('.')
    _sympy_major_minor = (int(sympy_version[0]), int(sympy_version[1]))
except (ValueError, IndexError):
    _sympy_major_minor = None

# Tuple comparison so that e.g. 0.7.x is also treated as "older than 1.4"
# (comparing major and minor separately would miss it).
if _sympy_major_minor is not None and _sympy_major_minor <= (1, 4):
    def hash_fun(self):
        """Hash an Assignment by its left- and right-hand side."""
        return hash((self.lhs, self.rhs))

    Assignment.__hash__ = hash_fun

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def assignment_from_stencil(stencil_array, input_field, output_field,
                            normalization_factor=None, order='visual') -> Assignment:
    """Creates an assignment that applies a stencil to ``input_field`` and stores the result in ``output_field``.

    Args:
        stencil_array: nested list or numpy array defining the stencil weights
        input_field: field or field access, defining where the stencil should be applied to
        output_field: field or field access where the result is written to
        normalization_factor: optional factor the weighted sum is multiplied with;
                              ``None`` (the default) means no normalization
        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
               For details see examples

    Returns:
        Assignment that can be used to create a kernel

    Raises:
        ValueError: if ``order`` is neither 'visual' nor 'numpy'

    Examples:
        >>> import pystencils as ps
        >>> f, g = ps.fields("f, g: [2D]")
        >>> stencil = [[0, 2, 0],
        ...            [3, 4, 5],
        ...            [0, 6, 0]]

        By default 'visual' ordering is used - i.e. the stencil is applied as the nested lists are written down

        >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
        >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
        True

        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.

        >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
        >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
        True

        You can also pass field accesses to apply the stencil at an already shifted position:

        >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
        True
    """
    from pystencils.field import Field

    stencil_array = np.array(stencil_array)
    if order == 'visual':
        # As written down, rows correspond to y (top row = +1, see examples) and columns to x.
        # Swap the first two axes to get [x, y] indexing, then flip y so index order matches
        # coordinate order.
        stencil_array = np.swapaxes(stencil_array, 0, 1)
        stencil_array = np.flip(stencil_array, axis=1)
    elif order == 'numpy':
        pass
    else:
        raise ValueError("'order' has to be either 'visual' or 'numpy'")

    # Whole fields are applied at their center position.
    if isinstance(input_field, Field):
        input_field = input_field.center
    if isinstance(output_field, Field):
        output_field = output_field.center

    # The stencil's center entry is at index shape // 2 along every axis.
    center = tuple(s // 2 for s in stencil_array.shape)

    rhs = 0
    for index, factor in np.ndenumerate(stencil_array):
        shift = tuple(i - c for i, c in zip(index, center))
        rhs += factor * input_field.get_shifted(*shift)

    # Explicit None check: a falsy factor such as 0 must still be applied, not silently ignored.
    if normalization_factor is not None:
        rhs *= normalization_factor

    return Assignment(output_field, rhs)