Skip to content
Snippets Groups Projects
Commit d8e498fa authored by Martin Bauer's avatar Martin Bauer
Browse files

Workaround for sympy bug in placeholder_function

see https://github.com/sympy/sympy/issues/16662
parent 27a131fb
No related merge requests found
...@@ -2,6 +2,7 @@ import sympy as sp ...@@ -2,6 +2,7 @@ import sympy as sp
from typing import List from typing import List
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.sympyextensions import is_constant
from pystencils.transformations import generic_visit from pystencils.transformations import generic_visit
...@@ -37,11 +38,11 @@ def to_placeholder_function(expr, name): ...@@ -37,11 +38,11 @@ def to_placeholder_function(expr, name):
assignments = [Assignment(sp.Symbol(name), expr)] assignments = [Assignment(sp.Symbol(name), expr)]
assignments += [Assignment(symbol, derivative) assignments += [Assignment(symbol, derivative)
for symbol, derivative in zip(derivative_symbols, derivatives) for symbol, derivative in zip(derivative_symbols, derivatives)
if not derivative.is_constant()] if not is_constant(derivative)]
def fdiff(_, index): def fdiff(_, index):
result = derivatives[index - 1] result = derivatives[index - 1]
return result if result.is_constant() else derivative_symbols[index - 1] return result if is_constant(result) else derivative_symbols[index - 1]
func = type(name, (sp.Function, PlaceholderFunction), func = type(name, (sp.Function, PlaceholderFunction),
{'fdiff': fdiff, {'fdiff': fdiff,
......
...@@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict,
return visit(expression) return visit(expression)
def is_constant(expr):
"""Simple version of checking if a sympy expression is constant.
Works also for piecewise defined functions - sympy's is_constant() has a problem there, see:
https://github.com/sympy/sympy/issues/16662
"""
return len(expr.free_symbols) == 0
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_replacement: Optional[Union[int, float]] = 0.5,
required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment