Commit d8e498fa by Martin Bauer

### Workaround for sympy bug in placeholder_function

`see https://github.com/sympy/sympy/issues/16662`
parent 27a131fb
 ... @@ -2,6 +2,7 @@ import sympy as sp ... @@ -2,6 +2,7 @@ import sympy as sp from typing import List from typing import List from pystencils.assignment import Assignment from pystencils.assignment import Assignment from pystencils.astnodes import Node from pystencils.astnodes import Node from pystencils.sympyextensions import is_constant from pystencils.transformations import generic_visit from pystencils.transformations import generic_visit ... @@ -37,11 +38,11 @@ def to_placeholder_function(expr, name): ... @@ -37,11 +38,11 @@ def to_placeholder_function(expr, name): assignments = [Assignment(sp.Symbol(name), expr)] assignments = [Assignment(sp.Symbol(name), expr)] assignments += [Assignment(symbol, derivative) assignments += [Assignment(symbol, derivative) for symbol, derivative in zip(derivative_symbols, derivatives) for symbol, derivative in zip(derivative_symbols, derivatives) if not derivative.is_constant()] if not is_constant(derivative)] def fdiff(_, index): def fdiff(_, index): result = derivatives[index - 1] result = derivatives[index - 1] return result if result.is_constant() else derivative_symbols[index - 1] return result if is_constant(result) else derivative_symbols[index - 1] func = type(name, (sp.Function, PlaceholderFunction), func = type(name, (sp.Function, PlaceholderFunction), {'fdiff': fdiff, {'fdiff': fdiff, ... ...
 ... @@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict, ... @@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict, return visit(expression) return visit(expression) def is_constant(expr): """Simple version of checking if a sympy expression is constant. Works also for piecewise defined functions - sympy's is_constant() has a problem there, see: https://github.com/sympy/sympy/issues/16662 """ return len(expr.free_symbols) == 0 def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!