Newer
Older
import pystencils
from pystencils.sympyextensions import replace_second_order_products
from pystencils.sympyextensions import remove_higher_order_terms
from pystencils.sympyextensions import complete_the_squares_in_exp
from pystencils.sympyextensions import extract_most_common_factor
from pystencils.sympyextensions import simplify_by_equality
from pystencils.sympyextensions import count_operations
from pystencils.sympyextensions import common_denominator
from pystencils.sympyextensions import get_symmetric_part
from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta
from pystencils import Assignment
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts)
def test_utility():
a = [1, 2]
b = (2, 3)
a_np = np.array(a)
b_np = np.array(b)
assert scalar_product(a, b) == np.dot(a_np, b_np)
a = sympy.Symbol("a")
b = sympy.Symbol("b")
assert kronecker_delta(a, a, a, b) == 0
assert kronecker_delta(a, a, a, a) == 1
assert kronecker_delta(3, 3, 3, 2) == 0
assert kronecker_delta(2, 2, 2, 2) == 1
assert kronecker_delta([10] * 100) == 1
assert kronecker_delta((0, 1), (0, 1)) == 1
def test_replace_second_order_products():
x, y = sympy.symbols('x y')
expr = 4 * x * y
expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2)
expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2)
result = replace_second_order_products(expr, search_symbols=[x, y], positive=True)
assert result == expected_expr_positive
assert (result - expected_expr_positive).simplify() == 0
result = replace_second_order_products(expr, search_symbols=[x, y], positive=False)
assert result == expected_expr_negative
assert (result - expected_expr_negative).simplify() == 0
result = replace_second_order_products(expr, search_symbols=[x, y], positive=None)
assert result == expected_expr_positive
a = [Assignment(sympy.symbols('z'), x + y)]
replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a)
assert len(a) == 2
assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def test_remove_higher_order_terms():
x, y = sympy.symbols('x y')
expr = sympy.Mul(x, y)
result = remove_higher_order_terms(expr, order=1, symbols=[x, y])
assert result == 0
result = remove_higher_order_terms(expr, order=2, symbols=[x, y])
assert result == expr
expr = sympy.Pow(x, 3)
result = remove_higher_order_terms(expr, order=2, symbols=[x, y])
assert result == 0
result = remove_higher_order_terms(expr, order=3, symbols=[x, y])
assert result == expr
def test_complete_the_squares_in_exp():
a, b, c, s, n = sympy.symbols('a b c s n')
expr = a * s ** 2 + b * s + c
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expr
expr = sympy.exp(a * s ** 2 + b * s + c)
expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a))
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expected_result
def test_extract_most_common_factor():
x, y = sympy.symbols('x y')
expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y)
most_common_factor = extract_most_common_factor(expr)
assert most_common_factor[0] == 7
assert sympy.prod(most_common_factor) == expr
expr = 1 / x + 3 / (x + y) + 3 / y
most_common_factor = extract_most_common_factor(expr)
assert most_common_factor[0] == 3
assert sympy.prod(most_common_factor) == expr
expr = 1 / x
most_common_factor = extract_most_common_factor(expr)
assert most_common_factor[0] == 1
assert sympy.prod(most_common_factor) == expr
assert most_common_factor[1] == expr
def test_count_operations():
x, y, z = sympy.symbols('x y z')
expr = 1/x + y * sympy.sqrt(z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['muls'] == 1
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = 1 / sympy.sqrt(z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 0
assert ops['muls'] == 0
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = sympy.Rel(1 / sympy.sqrt(z), 5)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 0
assert ops['muls'] == 0
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = sympy.sqrt(x + y)
expr = insert_fast_sqrts(expr).atoms(fast_sqrt)
ops = count_operations(*expr, only_type=None)
assert ops['fast_sqrts'] == 1
expr = sympy.sqrt(x / y)
expr = insert_fast_divisions(expr).atoms(fast_division)
ops = count_operations(*expr, only_type=None)
assert ops['fast_div'] == 1
expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y))
expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt)
ops = count_operations(*expr, only_type=None)
assert ops['fast_inv_sqrts'] == 1
expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = DivFunc(x, y)
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
expr = DivFunc(x + z, y + z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 2
assert ops['divs'] == 1
expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
ops = count_operations(expr, only_type=None)
assert ops['muls'] == 99
expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
assert ops['muls'] == 99
expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['divs'] == 1
assert ops['muls'] == 99
def test_common_denominator():
x = sympy.symbols('x')
expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3)
cm = common_denominator(expr)
assert cm == 6
def test_get_symmetric_part():
x, y, z = sympy.symbols('x y z')
expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3
expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3
sym_part = get_symmetric_part(expr, sympy.symbols(f'y z'))
assert sym_part == expected_result
def test_simplify_by_equality():
x, y, z = sp.symbols('x, y, z')
p, q = sp.symbols('p, q')
# Let x = y + z
expr = x * p - y * p + z * q
expr = simplify_by_equality(expr, x, y, z)
assert expr == z * p + z * q
expr = x * (p - 2 * q) + 2 * q * z
expr = simplify_by_equality(expr, x, y, z)
assert expr == x * p - 2 * q * y
expr = x * (y + z) - y * z
expr = simplify_by_equality(expr, x, y, z)
assert expr == x*y + z**2
# Let x = y + 2
expr = x * p - 2 * p
expr = simplify_by_equality(expr, x, y, 2)
assert expr == y * p