assignment.py 4.11 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
# -*- coding: utf-8 -*-
2
import sympy as sp
3
4
from sympy.printing.latex import LatexPrinter

5
6
7
8
try:
    from sympy.codegen.ast import Assignment
except ImportError:
    Assignment = None
9
import numpy as np
10

11
# Public names exported by ``from <module> import *``; the printing helpers in this module are not exported.
__all__ = ['Assignment', 'assignment_from_stencil']
12
13
14
15
16
17


def print_assignment_latex(printer, expr):
    """Render an Assignment as LaTeX, with a left arrow between left- and right-hand side.

    sympy cannot print Assignments as LaTeX on its own, so this function is installed as the
    ``_print_Assignment`` method of sympy's LatexPrinter.

    Args:
        printer: the active printer; its ``doprint`` renders each side of the assignment
        expr: the assignment; its ``lhs`` and ``rhs`` attributes are printed (lhs first)

    Returns:
        LaTeX string of the form ``<lhs> \\leftarrow <rhs>``
    """
    rendered = dict(printed_lhs=printer.doprint(expr.lhs),
                    printed_rhs=printer.doprint(expr.rhs))
    return r"{printed_lhs} \leftarrow {printed_rhs}".format(**rendered)
19
20


Martin Bauer's avatar
Martin Bauer committed
21
def assignment_str(assignment):
    """Return a plain-text rendering ``<lhs> ← <rhs>`` of ``assignment``.

    Used as the ``__str__`` of ``Assignment`` in this module.
    """
    sides = dict(lhs=assignment.lhs, rhs=assignment.rhs)
    return r"{lhs} ← {rhs}".format_map(sides)
23
24
25
26
27
28
29
30
31
32


if Assignment is not None:
    # Recent sympy ships sympy.codegen.ast.Assignment: only give it a nicer plain-text representation.
    Assignment.__str__ = assignment_str
else:
    # Back port for older sympy versions that don't have Assignment yet.

    class Assignment(sp.Rel):  # pragma: no cover
        """Relation ``lhs := rhs`` (back port of ``sympy.codegen.ast.Assignment``).

        Both sides are sympified; the left-hand side must be a Symbol, MatrixSymbol,
        MatrixElement or Indexed, otherwise a TypeError is raised.
        """

        rel_op = ':='
        __slots__ = []

        def __new__(cls, lhs, rhs=0, **assumptions):
            from sympy.matrices.expressions.matexpr import (
                MatrixElement, MatrixSymbol)
            from sympy.tensor.indexed import Indexed
            lhs = sp.sympify(lhs)
            rhs = sp.sympify(rhs)
            # Tuple of things that can be on the lhs of an assignment
            assignable = (sp.Symbol, MatrixSymbol, MatrixElement, Indexed)
            if not isinstance(lhs, assignable):
                raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
            return sp.Rel.__new__(cls, lhs, rhs, **assumptions)

        __str__ = assignment_str
        # Kept for backward compatibility only: sympy's printers look up ``_print_<ClassName>``
        # on the printer, not on the expression, so this attribute does not enable LaTeX output.
        _print_Assignment = print_assignment_latex

# Register the LaTeX handler on the printer for BOTH code paths. Previously this only happened when
# sympy provided Assignment, so the back port fell through to the generic Relational printer,
# which does not know the ':=' operator.
LatexPrinter._print_Assignment = print_assignment_latex
52
53


54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def assignment_from_stencil(stencil_array, input_field, output_field,
                            normalization_factor=None, order='visual') -> Assignment:
    """Creates an assignment that applies a stencil of weights to ``input_field`` and stores the result.

    The right-hand side is the sum over all stencil entries of
    ``weight * input_field.get_shifted(*shift)``, where ``shift`` is the entry's position relative
    to the stencil center (index ``s // 2`` along every axis).

    Args:
        stencil_array: nested list or numpy array defining the stencil weights
        input_field: field or field access, defining where the stencil should be applied to
        output_field: field or field access where the result is written to
        normalization_factor: optional factor the weighted sum is multiplied with. ``None`` (the default)
                              means no normalization; any other value, including 0, is applied.
        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
               For details see examples

    Returns:
        Assignment that can be used to create a kernel

    Raises:
        ValueError: if ``order`` is neither 'visual' nor 'numpy'

    Examples:
        >>> import pystencils as ps
        >>> f, g = ps.fields("f, g: [2D]")
        >>> stencil = [[0, 2, 0],
        ...            [3, 4, 5],
        ...            [0, 6, 0]]

        By default 'visual' ordering is used - i.e. the stencil is applied as the nested lists are written down
        >>> assignment_from_stencil(stencil, f, g, order='visual')
        Assignment(g_C, 3*f_W + 6*f_S + 4*f_C + 2*f_N + 5*f_E)

        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
        >>> assignment_from_stencil(stencil, f, g, order='numpy')
        Assignment(g_C, 2*f_W + 3*f_S + 4*f_C + 5*f_N + 6*f_E)

        You can also pass field accesses to apply the stencil at an already shifted position:
        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0])
        Assignment(g_2E, 3*f_C + 6*f_SE + 4*f_E + 2*f_NE + 5*f_2E)
    """
    from pystencils import Field

    stencil_array = np.array(stencil_array)

    if order == 'visual':
        # Nested lists are written row by row, top to bottom: make the column index the first (x)
        # coordinate and flip the second so that the top row ends up at the largest y offset.
        stencil_array = np.swapaxes(stencil_array, 0, 1)
        stencil_array = np.flip(stencil_array, axis=1)
    elif order != 'numpy':
        raise ValueError("'order' has to be either 'visual' or 'numpy'")

    # A whole field is applied at its center cell.
    if isinstance(input_field, Field):
        input_field = input_field.center
    if isinstance(output_field, Field):
        output_field = output_field.center

    # Index of the stencil center along each axis; shifts are measured relative to it.
    offset = tuple(s // 2 for s in stencil_array.shape)

    rhs = 0
    for index, factor in np.ndenumerate(stencil_array):
        shift = tuple(i - o for i, o in zip(index, offset))
        rhs += factor * input_field.get_shifted(*shift)

    # Compare against None rather than relying on truthiness, so an explicit factor of 0 is honoured.
    if normalization_factor is not None:
        rhs *= normalization_factor

    return Assignment(output_field, rhs)