Skip to content
Snippets Groups Projects
test_sympyextensions.py 7.23 KiB
Newer Older
import sympy
Markus Holzer's avatar
Markus Holzer committed
import numpy as np
import sympy as sp
import pystencils

from pystencils.sympyextensions import replace_second_order_products
from pystencils.sympyextensions import remove_higher_order_terms
from pystencils.sympyextensions import complete_the_squares_in_exp
from pystencils.sympyextensions import extract_most_common_factor
from pystencils.sympyextensions import simplify_by_equality
from pystencils.sympyextensions import count_operations
from pystencils.sympyextensions import common_denominator
from pystencils.sympyextensions import get_symmetric_part
Markus Holzer's avatar
Markus Holzer committed
from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta

from pystencils import Assignment
from pystencils.functions import DivFunc
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
                                           insert_fast_divisions, insert_fast_sqrts)


Markus Holzer's avatar
Markus Holzer committed
def test_utility():
    a = [1, 2]
    b = (2, 3)

    a_np = np.array(a)
    b_np = np.array(b)
    assert scalar_product(a, b) == np.dot(a_np, b_np)

    a = sympy.Symbol("a")
    b = sympy.Symbol("b")

    assert kronecker_delta(a, a, a, b) == 0
    assert kronecker_delta(a, a, a, a) == 1
    assert kronecker_delta(3, 3, 3, 2) == 0
    assert kronecker_delta(2, 2, 2, 2) == 1
    assert kronecker_delta([10] * 100) == 1
    assert kronecker_delta((0, 1), (0, 1)) == 1


def test_replace_second_order_products():
    x, y = sympy.symbols('x y')
    expr = 4 * x * y
    expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2)
    expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2)

    result = replace_second_order_products(expr, search_symbols=[x, y], positive=True)
    assert result == expected_expr_positive
    assert (result - expected_expr_positive).simplify() == 0

    result = replace_second_order_products(expr, search_symbols=[x, y], positive=False)
    assert result == expected_expr_negative
    assert (result - expected_expr_negative).simplify() == 0

    result = replace_second_order_products(expr, search_symbols=[x, y], positive=None)
    assert result == expected_expr_positive

    a = [Assignment(sympy.symbols('z'), x + y)]
    replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a)
    assert len(a) == 2

Markus Holzer's avatar
Markus Holzer committed
    assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4


def test_remove_higher_order_terms():
    x, y = sympy.symbols('x y')

    expr = sympy.Mul(x, y)

    result = remove_higher_order_terms(expr, order=1, symbols=[x, y])
    assert result == 0
    result = remove_higher_order_terms(expr, order=2, symbols=[x, y])
    assert result == expr

    expr = sympy.Pow(x, 3)

    result = remove_higher_order_terms(expr, order=2, symbols=[x, y])
    assert result == 0
    result = remove_higher_order_terms(expr, order=3, symbols=[x, y])
    assert result == expr


def test_complete_the_squares_in_exp():
    a, b, c, s, n = sympy.symbols('a b c s n')
    expr = a * s ** 2 + b * s + c
    result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
    assert result == expr

    expr = sympy.exp(a * s ** 2 + b * s + c)
    expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a))
    result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
    assert result == expected_result


def test_extract_most_common_factor():
    x, y = sympy.symbols('x y')
    expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y)
    most_common_factor = extract_most_common_factor(expr)

    assert most_common_factor[0] == 7
    assert sympy.prod(most_common_factor) == expr

    expr = 1 / x + 3 / (x + y) + 3 / y
    most_common_factor = extract_most_common_factor(expr)

    assert most_common_factor[0] == 3
    assert sympy.prod(most_common_factor) == expr

    expr = 1 / x
    most_common_factor = extract_most_common_factor(expr)

    assert most_common_factor[0] == 1
    assert sympy.prod(most_common_factor) == expr
    assert most_common_factor[1] == expr


def test_count_operations():
    x, y, z = sympy.symbols('x y z')
    expr = 1/x + y * sympy.sqrt(z)
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 1
    assert ops['muls'] == 1
    assert ops['divs'] == 1
    assert ops['sqrts'] == 1

Markus Holzer's avatar
Markus Holzer committed
    expr = 1 / sympy.sqrt(z)
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 0
    assert ops['muls'] == 0
    assert ops['divs'] == 1
    assert ops['sqrts'] == 1

    expr = sympy.Rel(1 / sympy.sqrt(z), 5)
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 0
    assert ops['muls'] == 0
    assert ops['divs'] == 1
    assert ops['sqrts'] == 1

    expr = sympy.sqrt(x + y)
    expr = insert_fast_sqrts(expr).atoms(fast_sqrt)
    ops = count_operations(*expr, only_type=None)
    assert ops['fast_sqrts'] == 1

    expr = sympy.sqrt(x / y)
    expr = insert_fast_divisions(expr).atoms(fast_division)
    ops = count_operations(*expr, only_type=None)
    assert ops['fast_div'] == 1

    expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y))
    expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt)
    ops = count_operations(*expr, only_type=None)
    assert ops['fast_inv_sqrts'] == 1

    expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 1

    expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100)
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 1
Markus Holzer's avatar
Markus Holzer committed
    assert ops['muls'] == 99
    assert ops['divs'] == 1
    assert ops['sqrts'] == 1

    expr = DivFunc(x, y)
    ops = count_operations(expr, only_type=None)
    assert ops['divs'] == 1

    expr = DivFunc(x + z, y + z)
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 2
    assert ops['divs'] == 1

    expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
    ops = count_operations(expr, only_type=None)
    assert ops['muls'] == 99

    expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
    ops = count_operations(expr, only_type=None)
    assert ops['divs'] == 1
    assert ops['muls'] == 99

    expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
    ops = count_operations(expr, only_type=None)
    assert ops['adds'] == 1
    assert ops['divs'] == 1
    assert ops['muls'] == 99


def test_common_denominator():
    x = sympy.symbols('x')
    expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3)
    cm = common_denominator(expr)
    assert cm == 6


def test_get_symmetric_part():
    x, y, z = sympy.symbols('x y z')
    expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3
    expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3
    sym_part = get_symmetric_part(expr, sympy.symbols(f'y z'))

    assert sym_part == expected_result


def test_simplify_by_equality():
    x, y, z = sp.symbols('x, y, z')
    p, q = sp.symbols('p, q')

    #   Let x = y + z
    expr = x * p - y * p + z * q
    expr = simplify_by_equality(expr, x, y, z)
    assert expr == z * p + z * q

    expr = x * (p - 2 * q) + 2 * q * z
    expr = simplify_by_equality(expr, x, y, z)
    assert expr == x * p - 2 * q * y

    expr = x * (y + z) - y * z
    expr = simplify_by_equality(expr, x, y, z)
    assert expr == x*y + z**2

    #   Let x = y + 2
    expr = x * p - 2 * p
    expr = simplify_by_equality(expr, x, y, 2)
    assert expr == y * p