Skip to content
Snippets Groups Projects
Commit 67019520 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'fix_extract_constants' into 'master'

Fix: `recursive_collect` now fails silently

See merge request pycodegen/pystencils!301
parents 205d0a39 f266bd7d
No related merge requests found
Pipeline #42198 failed with stages
in 2 minutes and 55 seconds
......@@ -6,6 +6,7 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs
from sympy.core.numbers import Zero
......@@ -442,11 +443,14 @@ def extract_most_common_factor(term):
def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
......@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected_poly = sp.Poly(expr.collect(symbol), symbol)
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment