From f266bd7d38a9dabf4a617c3b36d065a36ee1476e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 28 Jul 2022 15:52:02 +0200 Subject: [PATCH] Fix: `recursive_collect` now fails silently --- pystencils/sympyextensions.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 1a99fa8eb..b9a452742 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -6,6 +6,7 @@ from functools import partial, reduce from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union import sympy as sp +from sympy import PolynomialError from sympy.functions import Abs from sympy.core.numbers import Zero @@ -442,11 +443,14 @@ def extract_most_common_factor(term): def recursive_collect(expr, symbols, order_by_occurences=False): - """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, + """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, and so on. + + ``expr`` must be rewritable as a polynomial in the given ``symbols``. + It it is not, ``recursive_collect`` will fail quietly, returning the original expression. Args: - expr: A sympy expression + expr: A sympy expression. symbols: A sequence of symbols order_by_occurences: If True, during recursive descent, always collect the symbol occuring most often in the expression. @@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False): if len(symbols) == 0: return expr symbol = symbols[0] - collected_poly = sp.Poly(expr.collect(symbol), symbol) + collected = expr.collect(symbol) + + try: + collected_poly = sp.Poly(collected, symbol) + except PolynomialError: + return expr + coeffs = collected_poly.all_coeffs()[::-1] rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs)) return rec_sum -- GitLab