Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from pystencils import fields, Assignment, AssignmentCollection
from pystencils.simp.subexpression_insertion import *
def test_subexpression_insertion():
f, g = fields('f(10), g(10) : [2D]')
xi = sp.symbols('xi_:10')
xi_set = set(xi)
subexpressions = [
Assignment(xi[0], -f(4)),
Assignment(xi[1], -(f(1) * f(2))),
Assignment(xi[2], 2.31 * f(5)),
Assignment(xi[3], 1.8 + f(5) + f(6)),
Assignment(xi[4], 5.7 + f(6)),
Assignment(xi[5], (f(4) + f(5))**2),
Assignment(xi[6], f(3)**2),
Assignment(xi[7], f(4)),
Assignment(xi[8], 13),
Assignment(xi[9], 0),
]
assignments = [Assignment(g(i), x) for i, x in enumerate(xi)]
ac = AssignmentCollection(assignments, subexpressions=subexpressions)
ac_ins = insert_symbol_times_minus_one(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[0]})
ac_ins = insert_constant_multiples(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[0], xi[2]})
ac_ins = insert_constant_additions(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[4]})
ac_ins = insert_squares(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[6]})
ac_ins = insert_aliases(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[7]})
ac_ins = insert_zeros(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[9]})
ac_ins = insert_constants(ac, skip={xi[9]})
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[8]})