diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index f63328d81781c578786c0019edb1549b83a53c83..861239fdce6560820157af29342be876285382af 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -453,6 +453,72 @@ def recursive_collect(expr, symbols, order_by_occurences=False): return rec_sum +def summands(expr): + return set(expr.args) if isinstance(expr, sp.Add) else {expr} + + +def simplify_by_equality(expr, a, b, c): + """ + Uses the equality a = b + c, where a and b must be symbols, to simplify expr + by attempting to express additive combinations of two quantities by the third. + + This works on expressions that are reducible to the form + :math:`a * (...) + b * (...) + c * (...)`, + without any mixed terms of a, b and c. + """ + if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol): + raise ValueError("a and b must be symbols.") + + c = sp.sympify(c) + + if not (isinstance(c, sp.Symbol) or is_constant(c)): + raise ValueError("c must be either a symbol or a constant!") + + expr = sp.sympify(expr) + + expr_expanded = sp.expand(expr) + a_coeff = expr_expanded.coeff(a, 1) + expr_expanded -= (a * a_coeff).expand() + b_coeff = expr_expanded.coeff(b, 1) + expr_expanded -= (b * b_coeff).expand() + if isinstance(c, sp.Symbol): + c_coeff = expr_expanded.coeff(c, 1) + rest = expr_expanded - (c * c_coeff).expand() + else: + c_coeff = expr_expanded / c + rest = 0 + + a_summands = summands(a_coeff) + b_summands = summands(b_coeff) + c_summands = summands(c_coeff) + + # replace b + c by a + b_plus_c_coeffs = b_summands & c_summands + for coeff in b_plus_c_coeffs: + rest += a * coeff + b_summands -= b_plus_c_coeffs + c_summands -= b_plus_c_coeffs + + # replace a - b by c + neg_b_summands = {-x for x in b_summands} + a_minus_b_coeffs = a_summands & neg_b_summands + for coeff in a_minus_b_coeffs: + rest += c * coeff + a_summands -= a_minus_b_coeffs + b_summands -= {-x for x in a_minus_b_coeffs} + + # replace a - c by b + neg_c_summands = {-x for x in c_summands} + a_minus_c_coeffs = a_summands & neg_c_summands + for coeff in a_minus_c_coeffs: + rest += b * coeff + a_summands -= a_minus_c_coeffs + c_summands -= {-x for x in a_minus_c_coeffs} + + # put it back together + return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand() + + def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], only_type: Optional[str] = 'real') -> Dict[str, int]: """Counts the number of additions, multiplications and division. diff --git a/pystencils_tests/test_sympyextensions.py b/pystencils_tests/test_sympyextensions.py index 82e0ef40206a293b99ed568ab6b7c75f28fd43a7..38a138d2b0d4a52d56214a6f2c5c4f0a7dedfd9c 100644 --- a/pystencils_tests/test_sympyextensions.py +++ b/pystencils_tests/test_sympyextensions.py @@ -1,11 +1,13 @@ import sympy import numpy as np +import sympy as sp import pystencils from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import complete_the_squares_in_exp from pystencils.sympyextensions import extract_most_common_factor +from pystencils.sympyextensions import simplify_by_equality from pystencils.sympyextensions import count_operations from pystencils.sympyextensions import common_denominator from pystencils.sympyextensions import get_symmetric_part @@ -176,3 +178,26 @@ def test_get_symmetric_part(): sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) assert sym_part == expected_result + + +def test_simplify_by_equality(): + x, y, z = sp.symbols('x, y, z') + p, q = sp.symbols('p, q') + + # Let x = y + z + expr = x * p - y * p + z * q + expr = simplify_by_equality(expr, x, y, z) + assert expr == z * p + z * q + + expr = x * (p - 2 * q) + 2 * q * z + expr = simplify_by_equality(expr, x, y, z) + assert expr == x * p - 2 * q * y + + expr = x * (y + z) - y * z + expr = simplify_by_equality(expr, x, y, z) + assert expr == x*y + z**2 + + # Let x = y + 2 + expr = x * p - 2 * p + expr = simplify_by_equality(expr, x, y, 2) + assert expr == y * p