diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 1a99fa8eb7f67fa89e650795f79d9f8ad2c3cb01..b9a452742aab763b5cb631525b303f4c106e2d64 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -6,6 +6,7 @@ from functools import partial, reduce
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import sympy as sp
+from sympy import PolynomialError
 from sympy.functions import Abs
 from sympy.core.numbers import Zero
 
@@ -442,11 +443,14 @@ def extract_most_common_factor(term):
 
 
 def recursive_collect(expr, symbols, order_by_occurences=False):
-    """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, 
+    """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
     and so on.
+    
+    ``expr`` must be rewritable as a polynomial in the given ``symbols``.
+    It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
 
     Args:
-        expr: A sympy expression
+        expr: A sympy expression.
         symbols: A sequence of symbols
         order_by_occurences: If True, during recursive descent, always collect the symbol occuring 
                              most often in the expression.
@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
     if len(symbols) == 0:
         return expr
     symbol = symbols[0]
-    collected_poly = sp.Poly(expr.collect(symbol), symbol)
+    collected = expr.collect(symbol)
+    
+    try:
+        collected_poly = sp.Poly(collected, symbol)
+    except PolynomialError:
+        return expr
+
     coeffs = collected_poly.all_coeffs()[::-1]
     rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
     return rec_sum