diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 1a99fa8eb7f67fa89e650795f79d9f8ad2c3cb01..b9a452742aab763b5cb631525b303f4c106e2d64 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -6,6 +6,7 @@ from functools import partial, reduce from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union import sympy as sp +from sympy import PolynomialError from sympy.functions import Abs from sympy.core.numbers import Zero @@ -442,11 +443,14 @@ def extract_most_common_factor(term): def recursive_collect(expr, symbols, order_by_occurences=False): - """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, + """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, and so on. + + ``expr`` must be rewritable as a polynomial in the given ``symbols``. + It it is not, ``recursive_collect`` will fail quietly, returning the original expression. Args: - expr: A sympy expression + expr: A sympy expression. symbols: A sequence of symbols order_by_occurences: If True, during recursive descent, always collect the symbol occuring most often in the expression. @@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False): if len(symbols) == 0: return expr symbol = symbols[0] - collected_poly = sp.Poly(expr.collect(symbol), symbol) + collected = expr.collect(symbol) + + try: + collected_poly = sp.Poly(collected, symbol) + except PolynomialError: + return expr + coeffs = collected_poly.all_coeffs()[::-1] rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs)) return rec_sum