import sympy as sp
from typing import List
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.sympyextensions import is_constant
from pystencils.transformations import generic_visit


class PlaceholderFunction:
    """Marker base class for the function types built by `to_placeholder_function`.

    It carries no behaviour of its own: `to_placeholder_function` adds it as a
    second base class of every generated function, and
    `remove_placeholder_functions` recognises those functions again through
    ``isinstance(e, PlaceholderFunction)``.
    """


def to_placeholder_function(expr, name):
    """Replaces an expression by a sympy function.

    - replacing an expression with just a symbol would lead to problems when calculating derivatives
    - placeholder functions get rid of this problem

    Examples:
        >>> x, t = sp.symbols("x, t")
        >>> temperature = x**2 + t**4 # some 'complicated' dependency
        >>> temperature_placeholder = to_placeholder_function(temperature, 'T')
        >>> diffusivity = temperature_placeholder + 42 * t
        >>> sp.diff(diffusivity, t)  # returns a symbol instead of the computed derivative
        _dT_dt + 42
        >>> result, subexpr = remove_placeholder_functions(diffusivity)
        >>> result
        T + 42*t
        >>> subexpr
        [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]

    Args:
        expr: sympy expression that should be hidden behind the placeholder
        name: name of the generated function class; it is also used for the symbol that
              stands for the expression value and inside the derivative symbol names
              ``_d<name>_d<symbol>``

    Returns:
        the generated function applied to all symbols occurring in ``expr``, sorted by name
    """
    # Sorting by name gives the function arguments (and thus the derivative indices
    # passed to fdiff) a reproducible order.
    symbols = sorted(expr.atoms(sp.Symbol), key=lambda e: e.name)
    derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols]
    derivatives = [sp.diff(expr, s) for s in symbols]

    # First assignment defines the value itself; derivatives only get an assignment
    # (and a symbol, see fdiff) when `is_constant` does not hold for them.
    assignments = [Assignment(sp.Symbol(name), expr)]
    assignments += [Assignment(symbol, derivative)
                    for symbol, derivative in zip(derivative_symbols, derivatives)
                    if not is_constant(derivative)]

    def fdiff(_, index):
        # `index` is used as a 1-based argument position, hence the `- 1`.
        result = derivatives[index - 1]
        return result if is_constant(result) else derivative_symbols[index - 1]

    # Build a fresh sympy function type that additionally derives from the
    # PlaceholderFunction marker so remove_placeholder_functions can find it again.
    func = type(name, (sp.Function, PlaceholderFunction),
                {'fdiff': fdiff,
                 'value': sp.Symbol(name),
                 'subexpressions': assignments,
                 'nargs': len(symbols)})
    return func(*symbols)


def remove_placeholder_functions(expr):
    """Replaces placeholder functions by their value symbol and collects their definitions.

    Every `PlaceholderFunction` instance found in ``expr`` is replaced by its ``value``
    attribute. The assignments stored in its ``subexpressions`` attribute are gathered in
    order of first occurrence, skipping assignments whose left hand side was already collected.
    `Node` instances are returned unchanged and are not descended into.

    Args:
        expr: object handed to `generic_visit`, which applies the replacement visitor to it

    Returns:
        tuple ``(new_expr, subexpressions)`` where ``new_expr`` is the result of
        `generic_visit` and ``subexpressions`` is the list of collected assignments
    """
    subexpressions = []
    # Left hand sides already present in `subexpressions`; kept alongside the list
    # so the duplicate check does not have to rebuild a set for every candidate.
    known_lhs = set()

    def visit(e):
        if isinstance(e, Node):
            return e
        elif isinstance(e, PlaceholderFunction):
            for se in e.subexpressions:
                if se.lhs not in known_lhs:
                    known_lhs.add(se.lhs)
                    subexpressions.append(se)
            return e.value
        else:
            new_args = [visit(a) for a in e.args]
            # Objects without args are leaves and are kept as they are.
            return e.func(*new_args) if new_args else e

    return generic_visit(expr, visit), subexpressions


def prepend_placeholder_functions(assignments: List[Assignment]):
    """Strips placeholder functions from ``assignments`` and puts their definitions in front.

    Args:
        assignments: list of assignments that may contain placeholder functions

    Returns:
        the subexpressions collected by `remove_placeholder_functions`, followed by the
        rewritten assignments (presumably a list as well, since both are joined with ``+``)
    """
    rewritten, definitions = remove_placeholder_functions(assignments)
    return definitions + rewritten