# placeholder_function.py
 Martin Bauer committed Apr 01, 2019 1 2 3 ``````import sympy as sp from typing import List from pystencils.assignment import Assignment `````` Martin Bauer committed Apr 01, 2019 4 ``````from pystencils.astnodes import Node `````` Martin Bauer committed Apr 16, 2019 5 ``````from pystencils.sympyextensions import is_constant `````` Martin Bauer committed Apr 01, 2019 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 ``````from pystencils.transformations import generic_visit class PlaceholderFunction: pass def to_placeholder_function(expr, name): """Replaces an expression by a sympy function. - replacing an expression with just a symbol would lead to problem when calculating derivatives - placeholder functions get rid of this problem Examples: >>> x, t = sp.symbols("x, t") >>> temperature = x**2 + t**4 # some 'complicated' dependency >>> temperature_placeholder = to_placeholder_function(temperature, 'T') >>> diffusivity = temperature_placeholder + 42 * t >>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative _dT_dt + 42 >>> result, subexpr = remove_placeholder_functions(diffusivity) >>> result T + 42*t >>> subexpr [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)] """ symbols = list(expr.atoms(sp.Symbol)) symbols.sort(key=lambda e: e.name) derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols] derivatives = [sp.diff(expr, s) for s in symbols] assignments = [Assignment(sp.Symbol(name), expr)] assignments += [Assignment(symbol, derivative) for symbol, derivative in zip(derivative_symbols, derivatives) `````` Martin Bauer committed Apr 16, 2019 41 `````` if not is_constant(derivative)] `````` Martin Bauer committed Apr 01, 2019 42 43 44 `````` def fdiff(_, index): result = derivatives[index - 1] `````` Martin Bauer committed Apr 16, 2019 45 `````` return result if is_constant(result) else derivative_symbols[index - 1] `````` Martin 
Bauer committed Apr 01, 2019 46 47 48 49 50 51 52 53 54 55 56 57 58 `````` func = type(name, (sp.Function, PlaceholderFunction), {'fdiff': fdiff, 'value': sp.Symbol(name), 'subexpressions': assignments, 'nargs': len(symbols)}) return func(*symbols) def remove_placeholder_functions(expr): subexpressions = [] def visit(e): `````` Martin Bauer committed Apr 01, 2019 59 60 61 `````` if isinstance(e, Node): return e elif isinstance(e, PlaceholderFunction): `````` Martin Bauer committed Apr 01, 2019 62 63 64 65 66 67 68 69 70 71 72 73 74 75 `````` for se in e.subexpressions: if se.lhs not in {a.lhs for a in subexpressions}: subexpressions.append(se) return e.value else: new_args = [visit(a) for a in e.args] return e.func(*new_args) if new_args else e return generic_visit(expr, visit), subexpressions def prepend_placeholder_functions(assignments: List[Assignment]): result, subexpressions = remove_placeholder_functions(assignments) return subexpressions + result``````