From d8e498fa9b62d6c7afd68af51e41bf4b9c16aa36 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 16 Apr 2019 15:19:09 +0200 Subject: [PATCH] Workaround for sympy bug in placeholder_function see https://github.com/sympy/sympy/issues/16662 --- pystencils/placeholder_function.py | 5 +++-- pystencils/sympyextensions.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pystencils/placeholder_function.py b/pystencils/placeholder_function.py index 67bae6cea..aca93d169 100644 --- a/pystencils/placeholder_function.py +++ b/pystencils/placeholder_function.py @@ -2,6 +2,7 @@ import sympy as sp from typing import List from pystencils.assignment import Assignment from pystencils.astnodes import Node +from pystencils.sympyextensions import is_constant from pystencils.transformations import generic_visit @@ -37,11 +38,11 @@ def to_placeholder_function(expr, name): assignments = [Assignment(sp.Symbol(name), expr)] assignments += [Assignment(symbol, derivative) for symbol, derivative in zip(derivative_symbols, derivatives) - if not derivative.is_constant()] + if not is_constant(derivative)] def fdiff(_, index): result = derivatives[index - 1] - return result if result.is_constant() else derivative_symbols[index - 1] + return result if is_constant(result) else derivative_symbols[index - 1] func = type(name, (sp.Function, PlaceholderFunction), {'fdiff': fdiff, diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 3dc84f146..e0906f84c 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict, return visit(expression) +def is_constant(expr): + """Simple version of checking if a sympy expression is constant. + Works also for piecewise defined functions - sympy's is_constant() has a problem there, see: + https://github.com/sympy/sympy/issues/16662 + """ + return len(expr.free_symbols) == 0 + def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: -- GitLab