From d8e498fa9b62d6c7afd68af51e41bf4b9c16aa36 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 16 Apr 2019 15:19:09 +0200
Subject: [PATCH] Workaround for sympy bug in placeholder_function

see https://github.com/sympy/sympy/issues/16662
---
 pystencils/placeholder_function.py | 5 +++--
 pystencils/sympyextensions.py      | 7 +++++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/pystencils/placeholder_function.py b/pystencils/placeholder_function.py
index 67bae6cea..aca93d169 100644
--- a/pystencils/placeholder_function.py
+++ b/pystencils/placeholder_function.py
@@ -2,6 +2,7 @@ import sympy as sp
 from typing import List
 from pystencils.assignment import Assignment
 from pystencils.astnodes import Node
+from pystencils.sympyextensions import is_constant
 from pystencils.transformations import generic_visit
 
 
@@ -37,11 +38,11 @@ def to_placeholder_function(expr, name):
     assignments = [Assignment(sp.Symbol(name), expr)]
     assignments += [Assignment(symbol, derivative)
                     for symbol, derivative in zip(derivative_symbols, derivatives)
-                    if not derivative.is_constant()]
+                    if not is_constant(derivative)]
 
     def fdiff(_, index):
         result = derivatives[index - 1]
-        return result if result.is_constant() else derivative_symbols[index - 1]
+        return result if is_constant(result) else derivative_symbols[index - 1]
 
     func = type(name, (sp.Function, PlaceholderFunction),
                 {'fdiff': fdiff,
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 3dc84f146..e0906f84c 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict,
         return visit(expression)
 
 
+def is_constant(expr):
+    """Simple version of checking if a sympy expression is constant.
+    Works also for piecewise defined functions - sympy's is_constant() has a problem there, see:
+    https://github.com/sympy/sympy/issues/16662
+    """
+    return len(expr.free_symbols) == 0
+
 def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
                   required_match_replacement: Optional[Union[int, float]] = 0.5,
                   required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
-- 
GitLab