Commit 8bc8b39a by Markus Holzer

### added test cases for sympyextensions

parent fb85c0b7
 ... ... @@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], positive: Optional[bool] = None, replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: """Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ). This makes the term longer - simplify usually is undoing these - however this transformation can be done to find more common sub-expressions ... ... @@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym if expr.is_Mul: distinct_search_symbols = set() nr_of_search_terms = 0 other_factors = 1 other_factors = sp.Integer(1) for t in expr.args: if t in search_symbols: nr_of_search_terms += 1 ... ... @@ -481,7 +481,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], pass elif t.func is sp.Mul: if check_type(t): result['muls'] += len(t.args) - 1 result['muls'] += len(t.args) for a in t.args: if a == 1 or a == -1: result['muls'] -= 1 ... ... @@ -515,7 +515,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], elif sp.nsimplify(t.exp) == sp.Rational(1, 2): result['sqrts'] += 1 else: warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node") warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node") else: warnings.warn("Counting operations: only integer exponents are supported in Pow, " "counting will be inaccurate") ... ... @@ -526,7 +526,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], elif isinstance(t, sp.Rel): pass else: warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate") if visit_children: for a in t.args: ... ...
 import sympy import pystencils from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import complete_the_squares_in_exp from pystencils.sympyextensions import extract_most_common_factor from pystencils.sympyextensions import count_operations from pystencils.sympyextensions import common_denominator from pystencils.sympyextensions import get_symmetric_part from pystencils import Assignment from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) def test_replace_second_order_products(): x, y = sympy.symbols('x y') expr = 4 * x * y expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2) expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2) result = replace_second_order_products(expr, search_symbols=[x, y], positive=True) assert result == expected_expr_positive assert (result - expected_expr_positive).simplify() == 0 result = replace_second_order_products(expr, search_symbols=[x, y], positive=False) assert result == expected_expr_negative assert (result - expected_expr_negative).simplify() == 0 result = replace_second_order_products(expr, search_symbols=[x, y], positive=None) assert result == expected_expr_positive a = [Assignment(sympy.symbols('z'), x + y)] replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a) assert len(a) == 2 def test_remove_higher_order_terms(): x, y = sympy.symbols('x y') expr = sympy.Mul(x, y) result = remove_higher_order_terms(expr, order=1, symbols=[x, y]) assert result == 0 result = remove_higher_order_terms(expr, order=2, symbols=[x, y]) assert result == expr expr = sympy.Pow(x, 3) result = remove_higher_order_terms(expr, order=2, symbols=[x, y]) assert result == 0 result = remove_higher_order_terms(expr, order=3, symbols=[x, y]) assert result == expr def test_complete_the_squares_in_exp(): a, b, c, s, n = sympy.symbols('a b c s n') expr = a * s ** 2 + b * s + c result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) assert result == expr expr = sympy.exp(a * s ** 2 + b * s + c) expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a)) result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) assert result == expected_result def test_extract_most_common_factor(): x, y = sympy.symbols('x y') expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y) most_common_factor = extract_most_common_factor(expr) assert most_common_factor[0] == 7 assert sympy.prod(most_common_factor) == expr expr = 1 / x + 3 / (x + y) + 3 / y most_common_factor = extract_most_common_factor(expr) assert most_common_factor[0] == 3 assert sympy.prod(most_common_factor) == expr expr = 1 / x most_common_factor = extract_most_common_factor(expr) assert most_common_factor[0] == 1 assert sympy.prod(most_common_factor) == expr assert most_common_factor[1] == expr def test_count_operations(): x, y, z = sympy.symbols('x y z') expr = 1/x + y * sympy.sqrt(z) ops = count_operations(expr, only_type=None) assert ops['adds'] == 1 assert ops['muls'] == 1 assert ops['divs'] == 1 assert ops['sqrts'] == 1 expr = sympy.sqrt(x + y) expr = insert_fast_sqrts(expr).atoms(fast_sqrt) ops = count_operations(*expr, only_type=None) assert ops['fast_sqrts'] == 1 expr = sympy.sqrt(x / y) expr = insert_fast_divisions(expr).atoms(fast_division) ops = count_operations(*expr, only_type=None) assert ops['fast_div'] == 1 expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y)) expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt) ops = count_operations(*expr, only_type=None) assert ops['fast_inv_sqrts'] == 1 expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z ops = count_operations(expr, only_type=None) assert ops['adds'] == 1 expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100) ops = count_operations(expr, only_type=None) assert ops['adds'] == 1 assert ops['muls'] == 100 assert ops['divs'] == 1 assert ops['sqrts'] == 1 def test_common_denominator(): x = sympy.symbols('x') expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3) cm = common_denominator(expr) assert cm == 6 def test_get_symmetric_part(): x, y, z = sympy.symbols('x y z') expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3 expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3 sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) assert sym_part == expected_result
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!