Skip to content
Snippets Groups Projects
Commit 1bd1c6d6 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils, new feature: placeholder functions

parent 2922cdba
Branches
Tags
No related merge requests found
import sympy as sp
from typing import List
from pystencils.assignment import Assignment
from pystencils.transformations import generic_visit
class PlaceholderFunction:
pass
def to_placeholder_function(expr, name):
"""Replaces an expression by a sympy function.
- replacing an expression with just a symbol would lead to problem when calculating derivatives
- placeholder functions get rid of this problem
Examples:
>>> x, t = sp.symbols("x, t")
>>> temperature = x**2 + t**4 # some 'complicated' dependency
>>> temperature_placeholder = to_placeholder_function(temperature, 'T')
>>> diffusivity = temperature_placeholder + 42 * t
>>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative
_dT_dt + 42
>>> result, subexpr = remove_placeholder_functions(diffusivity)
>>> result
T + 42*t
>>> subexpr
[Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]
"""
symbols = list(expr.atoms(sp.Symbol))
symbols.sort(key=lambda e: e.name)
derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols]
derivatives = [sp.diff(expr, s) for s in symbols]
assignments = [Assignment(sp.Symbol(name), expr)]
assignments += [Assignment(symbol, derivative)
for symbol, derivative in zip(derivative_symbols, derivatives)
if not derivative.is_constant()]
def fdiff(_, index):
result = derivatives[index - 1]
return result if result.is_constant() else derivative_symbols[index - 1]
func = type(name, (sp.Function, PlaceholderFunction),
{'fdiff': fdiff,
'value': sp.Symbol(name),
'subexpressions': assignments,
'nargs': len(symbols)})
return func(*symbols)
def remove_placeholder_functions(expr):
subexpressions = []
def visit(e):
if isinstance(e, PlaceholderFunction):
for se in e.subexpressions:
if se.lhs not in {a.lhs for a in subexpressions}:
subexpressions.append(se)
return e.value
else:
new_args = [visit(a) for a in e.args]
return e.func(*new_args) if new_args else e
return generic_visit(expr, visit), subexpressions
def prepend_placeholder_functions(assignments: List[Assignment]):
result, subexpressions = remove_placeholder_functions(assignments)
return subexpressions + result
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment