### Merge branch 'simplify_equality' into 'master'

```Added simplify_by_equality

See merge request pycodegen/pystencils!286```
parents be198ac4 8c53c16a
Pipeline #38427 failed with stages
in 3 minutes and 6 seconds
 ... ... @@ -453,6 +453,72 @@ def recursive_collect(expr, symbols, order_by_occurences=False): return rec_sum def summands(expr): return set(expr.args) if isinstance(expr, sp.Add) else {expr} def simplify_by_equality(expr, a, b, c): """ Uses the equality a = b + c, where a and b must be symbols, to simplify expr by attempting to express additive combinations of two quantities by the third. This works on expressions that are reducible to the form :math:`a * (...) + b * (...) + c * (...)`, without any mixed terms of a, b and c. """ if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol): raise ValueError("a and b must be symbols.") c = sp.sympify(c) if not (isinstance(c, sp.Symbol) or is_constant(c)): raise ValueError("c must be either a symbol or a constant!") expr = sp.sympify(expr) expr_expanded = sp.expand(expr) a_coeff = expr_expanded.coeff(a, 1) expr_expanded -= (a * a_coeff).expand() b_coeff = expr_expanded.coeff(b, 1) expr_expanded -= (b * b_coeff).expand() if isinstance(c, sp.Symbol): c_coeff = expr_expanded.coeff(c, 1) rest = expr_expanded - (c * c_coeff).expand() else: c_coeff = expr_expanded / c rest = 0 a_summands = summands(a_coeff) b_summands = summands(b_coeff) c_summands = summands(c_coeff) # replace b + c by a b_plus_c_coeffs = b_summands & c_summands for coeff in b_plus_c_coeffs: rest += a * coeff b_summands -= b_plus_c_coeffs c_summands -= b_plus_c_coeffs # replace a - b by c neg_b_summands = {-x for x in b_summands} a_minus_b_coeffs = a_summands & neg_b_summands for coeff in a_minus_b_coeffs: rest += c * coeff a_summands -= a_minus_b_coeffs b_summands -= {-x for x in a_minus_b_coeffs} # replace a - c by b neg_c_summands = {-x for x in c_summands} a_minus_c_coeffs = a_summands & neg_c_summands for coeff in a_minus_c_coeffs: rest += b * coeff a_summands -= a_minus_c_coeffs c_summands -= {-x for x in a_minus_c_coeffs} # put it back together return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand() def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], only_type: Optional[str] = 'real') -> Dict[str, int]: """Counts the number of additions, multiplications and division. ... ...
 import sympy import numpy as np import sympy as sp import pystencils from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import complete_the_squares_in_exp from pystencils.sympyextensions import extract_most_common_factor from pystencils.sympyextensions import simplify_by_equality from pystencils.sympyextensions import count_operations from pystencils.sympyextensions import common_denominator from pystencils.sympyextensions import get_symmetric_part ... ... @@ -176,3 +178,26 @@ def test_get_symmetric_part(): sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) assert sym_part == expected_result def test_simplify_by_equality(): x, y, z = sp.symbols('x, y, z') p, q = sp.symbols('p, q') # Let x = y + z expr = x * p - y * p + z * q expr = simplify_by_equality(expr, x, y, z) assert expr == z * p + z * q expr = x * (p - 2 * q) + 2 * q * z expr = simplify_by_equality(expr, x, y, z) assert expr == x * p - 2 * q * y expr = x * (y + z) - y * z expr = simplify_by_equality(expr, x, y, z) assert expr == x*y + z**2 # Let x = y + 2 expr = x * p - 2 * p expr = simplify_by_equality(expr, x, y, 2) assert expr == y * p
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment