assignment.py 2.13 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
# -*- coding: utf-8 -*-
2
import sympy as sp
3
4
from sympy.printing.latex import LatexPrinter

5
6
7
8
try:
    from sympy.codegen.ast import Assignment
except ImportError:
    Assignment = None
9
import numpy as np
10

11
__all__ = ['Assignment', 'assignment_from_stencil']
12
13
14
15
16
17


def print_assignment_latex(printer, expr):
    """Render an assignment as LaTeX, with a left arrow between both sides.

    sympy cannot print Assignments as LaTeX on its own, so this function is
    meant to be attached to the sympy LaTeX printer.

    Args:
        printer: object providing ``doprint``; it is used to render each side.
        expr: object exposing ``lhs`` and ``rhs`` attributes.

    Returns:
        The printed left-hand side, the LaTeX left-arrow command and the
        printed right-hand side, separated by single spaces.
    """
    printed_lhs = printer.doprint(expr.lhs)
    printed_rhs = printer.doprint(expr.rhs)
    # Raw string: "\l" is not a valid escape sequence, so a plain literal
    # triggers an invalid-escape warning on current Python versions.
    return r"{printed_lhs} \leftarrow {printed_rhs}".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
19
20


Martin Bauer's avatar
Martin Bauer committed
21
def assignment_str(assignment):
    """Return the plain-text form of an assignment: ``lhs``, a left arrow, ``rhs``.

    Args:
        assignment: object exposing ``lhs`` and ``rhs`` attributes.
    """
    template = "{lhs} ← {rhs}"
    sides = {'lhs': assignment.lhs, 'rhs': assignment.rhs}
    return template.format(**sides)


if Assignment is not None:
    # sympy provides Assignment itself; only its plain-text rendering is replaced.
    Assignment.__str__ = assignment_str

else:
    # Back-port for older sympy versions that do not have Assignment yet.

    class Assignment(sp.Rel):  # pragma: no cover
        """Stand-in for ``sympy.codegen.ast.Assignment`` on old sympy versions.

        Represents ``lhs := rhs`` as a relational object. The left-hand side
        must be a Symbol, MatrixSymbol, MatrixElement or Indexed.
        """

        rel_op = ':='
        __slots__ = []

        def __new__(cls, lhs, rhs=0, **assumptions):
            """Create the assignment after sympifying both sides.

            Raises:
                TypeError: if ``lhs`` is not one of the assignable types.
            """
            from sympy.matrices.expressions.matexpr import (
                MatrixElement, MatrixSymbol)
            from sympy.tensor.indexed import Indexed
            lhs = sp.sympify(lhs)
            rhs = sp.sympify(rhs)
            # Tuple of things that can be on the lhs of an assignment
            assignable = (sp.Symbol, MatrixSymbol, MatrixElement, Indexed)
            if not isinstance(lhs, assignable):
                raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
            return sp.Rel.__new__(cls, lhs, rhs, **assumptions)

        __str__ = assignment_str
        # Kept so existing attribute access keeps working. The printer does not
        # consult this attribute; the effective hook is installed below.
        _print_Assignment = print_assignment_latex

# sympy printers dispatch on ``_print_<ClassName>`` methods of the printer,
# so the LaTeX hook has to be registered on LatexPrinter for the native
# class and for the back-ported class alike.
LatexPrinter._print_Assignment = print_assignment_latex
52
53
54
55
56
57
58
59
60
61
62
63
64
65


def assignment_from_stencil(stencil_array, input_field, output_field, normalization_factor=None):
    """Build an assignment that applies a stencil to ``input_field``.

    Every stencil entry is multiplied with ``input_field`` indexed at that
    entry's position relative to the stencil centre, where the centre is
    ``shape // 2`` along each axis. The products are summed up and assigned
    to ``output_field.center()``.

    Args:
        stencil_array: array-like of stencil weights; converted with ``np.array``.
        input_field: object indexable with a tuple of relative integer offsets.
        output_field: object providing ``center()``, used as left-hand side.
        normalization_factor: optional factor the summed right-hand side is
            multiplied with; ``None`` means no normalization.

    Returns:
        An ``Assignment`` of the weighted sum to ``output_field.center()``.
    """
    stencil_array = np.array(stencil_array)
    # Index of the central entry; accesses are expressed relative to it.
    offset = tuple(s // 2 for s in stencil_array.shape)

    rhs = 0
    for index, factor in np.ndenumerate(stencil_array):
        relative_index = tuple(i - o for i, o in zip(index, offset))
        rhs += factor * input_field[relative_index]

    # Compare against None explicitly: a truthiness test would silently skip
    # a factor of zero and fails for values without an unambiguous truth value.
    if normalization_factor is not None:
        rhs *= normalization_factor

    return Assignment(output_field.center(), rhs)