From f266bd7d38a9dabf4a617c3b36d065a36ee1476e Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 28 Jul 2022 15:52:02 +0200
Subject: [PATCH] Fix: `recursive_collect` now fails silently

---
 pystencils/sympyextensions.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 1a99fa8eb..b9a452742 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -6,6 +6,7 @@ from functools import partial, reduce
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import sympy as sp
+from sympy import PolynomialError
 from sympy.functions import Abs
 from sympy.core.numbers import Zero
 
@@ -442,11 +443,14 @@ def extract_most_common_factor(term):
 
 
 def recursive_collect(expr, symbols, order_by_occurences=False):
-    """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, 
+    """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
     and so on.
+    
+    ``expr`` must be rewritable as a polynomial in the given ``symbols``.
+    It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
 
     Args:
-        expr: A sympy expression
+        expr: A sympy expression.
         symbols: A sequence of symbols
         order_by_occurences: If True, during recursive descent, always collect the symbol occuring 
                              most often in the expression.
@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
     if len(symbols) == 0:
         return expr
     symbol = symbols[0]
-    collected_poly = sp.Poly(expr.collect(symbol), symbol)
+    collected = expr.collect(symbol)
+    
+    try:
+        collected_poly = sp.Poly(collected, symbol)
+    except PolynomialError:
+        return expr
+
     coeffs = collected_poly.all_coeffs()[::-1]
     rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
     return rec_sum
-- 
GitLab