diff --git a/pystencils/placeholder_function.py b/pystencils/placeholder_function.py index 67bae6ceafd2a9334ffabbe8abcabcc1a3c7eaa7..aca93d16937d7efa37668b3c15d03f4aa7069166 100644 --- a/pystencils/placeholder_function.py +++ b/pystencils/placeholder_function.py @@ -2,6 +2,7 @@ import sympy as sp from typing import List from pystencils.assignment import Assignment from pystencils.astnodes import Node +from pystencils.sympyextensions import is_constant from pystencils.transformations import generic_visit @@ -37,11 +38,11 @@ def to_placeholder_function(expr, name): assignments = [Assignment(sp.Symbol(name), expr)] assignments += [Assignment(symbol, derivative) for symbol, derivative in zip(derivative_symbols, derivatives) - if not derivative.is_constant()] + if not is_constant(derivative)] def fdiff(_, index): result = derivatives[index - 1] - return result if result.is_constant() else derivative_symbols[index - 1] + return result if is_constant(result) else derivative_symbols[index - 1] func = type(name, (sp.Function, PlaceholderFunction), {'fdiff': fdiff, diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 3dc84f146474b9b708637205d4a625ba921b8cc8..e0906f84cca886e70a19eb548b36352d6385365d 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict, return visit(expression) +def is_constant(expr): + """Simple version of checking if a sympy expression is constant. + Works also for piecewise defined functions - sympy's is_constant() has a problem there, see: + https://github.com/sympy/sympy/issues/16662 + """ + return len(expr.free_symbols) == 0 + def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: