From 1bd1c6d6568c5c48ea376ce418ef43aa64b7ea31 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 29 Mar 2019 13:27:24 +0100 Subject: [PATCH] pystencils, new feature: placeholder functions --- pystencils/placeholder_function.py | 71 ++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 pystencils/placeholder_function.py diff --git a/pystencils/placeholder_function.py b/pystencils/placeholder_function.py new file mode 100644 index 0000000..73d8bda --- /dev/null +++ b/pystencils/placeholder_function.py @@ -0,0 +1,71 @@ +import sympy as sp +from typing import List +from pystencils.assignment import Assignment +from pystencils.transformations import generic_visit + + +class PlaceholderFunction: + pass + + +def to_placeholder_function(expr, name): + """Replaces an expression by a sympy function. + + - replacing an expression with just a symbol would lead to problem when calculating derivatives + - placeholder functions get rid of this problem + + Examples: + >>> x, t = sp.symbols("x, t") + >>> temperature = x**2 + t**4 # some 'complicated' dependency + >>> temperature_placeholder = to_placeholder_function(temperature, 'T') + >>> diffusivity = temperature_placeholder + 42 * t + >>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative + _dT_dt + 42 + >>> result, subexpr = remove_placeholder_functions(diffusivity) + >>> result + T + 42*t + >>> subexpr + [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)] + + """ + symbols = list(expr.atoms(sp.Symbol)) + symbols.sort(key=lambda e: e.name) + derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols] + derivatives = [sp.diff(expr, s) for s in symbols] + + assignments = [Assignment(sp.Symbol(name), expr)] + assignments += [Assignment(symbol, derivative) + for symbol, derivative in zip(derivative_symbols, derivatives) + if not derivative.is_constant()] + + def fdiff(_, index): + result = derivatives[index - 1] + return result if result.is_constant() else derivative_symbols[index - 1] + + func = type(name, (sp.Function, PlaceholderFunction), + {'fdiff': fdiff, + 'value': sp.Symbol(name), + 'subexpressions': assignments, + 'nargs': len(symbols)}) + return func(*symbols) + + +def remove_placeholder_functions(expr): + subexpressions = [] + + def visit(e): + if isinstance(e, PlaceholderFunction): + for se in e.subexpressions: + if se.lhs not in {a.lhs for a in subexpressions}: + subexpressions.append(se) + return e.value + else: + new_args = [visit(a) for a in e.args] + return e.func(*new_args) if new_args else e + + return generic_visit(expr, visit), subexpressions + + +def prepend_placeholder_functions(assignments: List[Assignment]): + result, subexpressions = remove_placeholder_functions(assignments) + return subexpressions + result -- GitLab