diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index cd9519d0668f842879ae57eafb1a5aec34878a2c..55fe74967f6deedeac45233dc198f562d072b485 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], positive: Optional[bool] = None, replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: - """Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). + """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ). This makes the term longer - simplify usually is undoing these - however this transformation can be done to find more common sub-expressions @@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym if expr.is_Mul: distinct_search_symbols = set() nr_of_search_terms = 0 - other_factors = 1 + other_factors = sp.Integer(1) for t in expr.args: if t in search_symbols: nr_of_search_terms += 1 @@ -481,7 +481,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], pass elif t.func is sp.Mul: if check_type(t): - result['muls'] += len(t.args) - 1 + result['muls'] += len(t.args) for a in t.args: if a == 1 or a == -1: result['muls'] -= 1 @@ -515,7 +515,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], elif sp.nsimplify(t.exp) == sp.Rational(1, 2): result['sqrts'] += 1 else: - warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node") + warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node") else: warnings.warn("Counting operations: only integer exponents are supported in Pow, " "counting will be inaccurate") @@ -526,7 +526,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], elif isinstance(t, sp.Rel): pass else: - warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") + warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate") if visit_children: for a in t.args: diff --git a/pystencils_tests/test_sympyextensions.py b/pystencils_tests/test_sympyextensions.py new file mode 100644 index 0000000000000000000000000000000000000000..1636df632670a4a909321999067389bdaaa56a62 --- /dev/null +++ b/pystencils_tests/test_sympyextensions.py @@ -0,0 +1,140 @@ +import sympy +import pystencils + +from pystencils.sympyextensions import replace_second_order_products +from pystencils.sympyextensions import remove_higher_order_terms +from pystencils.sympyextensions import complete_the_squares_in_exp +from pystencils.sympyextensions import extract_most_common_factor +from pystencils.sympyextensions import count_operations +from pystencils.sympyextensions import common_denominator +from pystencils.sympyextensions import get_symmetric_part + +from pystencils import Assignment +from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, + insert_fast_divisions, insert_fast_sqrts) + + +def test_replace_second_order_products(): + x, y = sympy.symbols('x y') + expr = 4 * x * y + expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2) + expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2) + + result = replace_second_order_products(expr, search_symbols=[x, y], positive=True) + assert result == expected_expr_positive + assert (result - expected_expr_positive).simplify() == 0 + + result = replace_second_order_products(expr, search_symbols=[x, y], positive=False) + assert result == expected_expr_negative + assert (result - expected_expr_negative).simplify() == 0 + + result = replace_second_order_products(expr, search_symbols=[x, y], positive=None) + assert result == expected_expr_positive + + a = [Assignment(sympy.symbols('z'), x + y)] + replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a) + assert len(a) == 2 + + +def test_remove_higher_order_terms(): + x, y = sympy.symbols('x y') + + expr = sympy.Mul(x, y) + + result = remove_higher_order_terms(expr, order=1, symbols=[x, y]) + assert result == 0 + result = remove_higher_order_terms(expr, order=2, symbols=[x, y]) + assert result == expr + + expr = sympy.Pow(x, 3) + + result = remove_higher_order_terms(expr, order=2, symbols=[x, y]) + assert result == 0 + result = remove_higher_order_terms(expr, order=3, symbols=[x, y]) + assert result == expr + + +def test_complete_the_squares_in_exp(): + a, b, c, s, n = sympy.symbols('a b c s n') + expr = a * s ** 2 + b * s + c + result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) + assert result == expr + + expr = sympy.exp(a * s ** 2 + b * s + c) + expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a)) + result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) + assert result == expected_result + + +def test_extract_most_common_factor(): + x, y = sympy.symbols('x y') + expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y) + most_common_factor = extract_most_common_factor(expr) + + assert most_common_factor[0] == 7 + assert sympy.prod(most_common_factor) == expr + + expr = 1 / x + 3 / (x + y) + 3 / y + most_common_factor = extract_most_common_factor(expr) + + assert most_common_factor[0] == 3 + assert sympy.prod(most_common_factor) == expr + + expr = 1 / x + most_common_factor = extract_most_common_factor(expr) + + assert most_common_factor[0] == 1 + assert sympy.prod(most_common_factor) == expr + assert most_common_factor[1] == expr + + +def test_count_operations(): + x, y, z = sympy.symbols('x y z') + expr = 1/x + y * sympy.sqrt(z) + ops = count_operations(expr, only_type=None) + assert ops['adds'] == 1 + assert ops['muls'] == 1 + assert ops['divs'] == 1 + assert ops['sqrts'] == 1 + + expr = sympy.sqrt(x + y) + expr = insert_fast_sqrts(expr).atoms(fast_sqrt) + ops = count_operations(*expr, only_type=None) + assert ops['fast_sqrts'] == 1 + + expr = sympy.sqrt(x / y) + expr = insert_fast_divisions(expr).atoms(fast_division) + ops = count_operations(*expr, only_type=None) + assert ops['fast_div'] == 1 + + expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y)) + expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt) + ops = count_operations(*expr, only_type=None) + assert ops['fast_inv_sqrts'] == 1 + + expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z + ops = count_operations(expr, only_type=None) + assert ops['adds'] == 1 + + expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100) + ops = count_operations(expr, only_type=None) + assert ops['adds'] == 1 + assert ops['muls'] == 100 + assert ops['divs'] == 1 + assert ops['sqrts'] == 1 + + +def test_common_denominator(): + x = sympy.symbols('x') + expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3) + cm = common_denominator(expr) + assert cm == 6 + + +def test_get_symmetric_part(): + x, y, z = sympy.symbols('x y z') + expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3 + expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3 + sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) + + assert sym_part == expected_result