Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jan Hönig
pystencils
Commits
b70ac70a
Commit
b70ac70a
authored
Jul 28, 2021
by
Markus Holzer
Browse files
Merge branch 'subexpression_insertion' into 'master'
Advanced Subexpression Insertion See merge request
pycodegen/pystencils!258
parents
2464ef8e
a384e104
Pipeline
#33790
passed with stages
in 13 minutes and 50 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pystencils/simp/__init__.py
View file @
b70ac70a
...
...
@@ -5,10 +5,17 @@ from .simplifications import (
add_subexpressions_for_sums
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
subexpression_substitution_in_existing_subexpressions
,
subexpression_substitution_in_main_assignments
,
sympy_cse
,
sympy_cse_on_assignment_list
)
from
.subexpression_insertion
import
(
insert_aliases
,
insert_zeros
,
insert_constants
,
insert_constant_additions
,
insert_constant_multiples
,
insert_squares
,
insert_symbol_times_minus_one
)
from
.simplificationstrategy
import
SimplificationStrategy
__all__
=
[
'AssignmentCollection'
,
'SimplificationStrategy'
,
'sympy_cse'
,
'sympy_cse_on_assignment_list'
,
'apply_to_all_assignments'
,
'apply_on_all_subexpressions'
,
'subexpression_substitution_in_existing_subexpressions'
,
'subexpression_substitution_in_main_assignments'
,
'add_subexpressions_for_constants'
,
'add_subexpressions_for_divisions'
,
'add_subexpressions_for_sums'
,
'add_subexpressions_for_field_reads'
]
'add_subexpressions_for_divisions'
,
'add_subexpressions_for_sums'
,
'add_subexpressions_for_field_reads'
,
'insert_aliases'
,
'insert_zeros'
,
'insert_constants'
,
'insert_constant_additions'
,
'insert_constant_multiples'
,
'insert_squares'
,
'insert_symbol_times_minus_one'
]
pystencils/simp/subexpression_insertion.py
0 → 100644
View file @
b70ac70a
import
sympy
as
sp
from
pystencils.sympyextensions
import
is_constant
# Subexpression Insertion
def
insert_subexpressions
(
ac
,
selection_callback
,
skip
=
set
()):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
i
=
0
while
i
<
len
(
ac
.
subexpressions
):
exp
=
ac
.
subexpressions
[
i
]
if
exp
.
lhs
not
in
skip
and
selection_callback
(
exp
):
ac
=
ac
.
new_with_inserted_subexpression
(
exp
.
lhs
)
else
:
i
+=
1
return
ac
def
insert_aliases
(
ac
,
**
kwargs
):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return
insert_subexpressions
(
ac
,
lambda
x
:
isinstance
(
x
.
rhs
,
sp
.
Symbol
),
**
kwargs
)
def
insert_zeros
(
ac
,
**
kwargs
):
"""Inserts subexpressions whose right-hand side is zero."""
zero
=
sp
.
Integer
(
0
)
return
insert_subexpressions
(
ac
,
lambda
x
:
x
.
rhs
==
zero
,
**
kwargs
)
def
insert_constants
(
ac
,
**
kwargs
):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return
insert_subexpressions
(
ac
,
lambda
x
:
is_constant
(
x
.
rhs
),
**
kwargs
)
def
insert_symbol_times_minus_one
(
ac
,
**
kwargs
):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def
callback
(
exp
):
rhs
=
exp
.
rhs
minus_one
=
sp
.
Integer
(
-
1
)
atoms
=
rhs
.
atoms
(
sp
.
Symbol
)
return
len
(
atoms
)
==
1
and
rhs
==
minus_one
*
atoms
.
pop
()
return
insert_subexpressions
(
ac
,
callback
,
**
kwargs
)
def
insert_constant_multiples
(
ac
,
**
kwargs
):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def
callback
(
exp
):
rhs
=
exp
.
rhs
symbols
=
rhs
.
atoms
(
sp
.
Symbol
)
numbers
=
rhs
.
atoms
(
sp
.
Number
)
return
len
(
symbols
)
==
1
and
len
(
numbers
)
==
1
and
\
rhs
==
numbers
.
pop
()
*
symbols
.
pop
()
return
insert_subexpressions
(
ac
,
callback
,
**
kwargs
)
def
insert_constant_additions
(
ac
,
**
kwargs
):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def
callback
(
exp
):
rhs
=
exp
.
rhs
symbols
=
rhs
.
atoms
(
sp
.
Symbol
)
numbers
=
rhs
.
atoms
(
sp
.
Number
)
return
len
(
symbols
)
==
1
and
len
(
numbers
)
==
1
and
\
rhs
==
numbers
.
pop
()
+
symbols
.
pop
()
return
insert_subexpressions
(
ac
,
callback
,
**
kwargs
)
def
insert_squares
(
ac
,
**
kwargs
):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def
callback
(
exp
):
rhs
=
exp
.
rhs
symbols
=
rhs
.
atoms
(
sp
.
Symbol
)
return
len
(
symbols
)
==
1
and
rhs
==
symbols
.
pop
()
**
2
return
insert_subexpressions
(
ac
,
callback
,
**
kwargs
)
def
bind_symbols_to_skip
(
insertion_function
,
skip
):
return
lambda
ac
:
insertion_function
(
ac
,
skip
=
skip
)
pystencils_tests/test_subexpression_insertion.py
0 → 100644
View file @
b70ac70a
import
sympy
as
sp
from
pystencils
import
fields
,
Assignment
,
AssignmentCollection
from
pystencils.simp.subexpression_insertion
import
*
def
test_subexpression_insertion
():
f
,
g
=
fields
(
'f(10), g(10) : [2D]'
)
xi
=
sp
.
symbols
(
'xi_:10'
)
xi_set
=
set
(
xi
)
subexpressions
=
[
Assignment
(
xi
[
0
],
-
f
(
4
)),
Assignment
(
xi
[
1
],
-
(
f
(
1
)
*
f
(
2
))),
Assignment
(
xi
[
2
],
2.31
*
f
(
5
)),
Assignment
(
xi
[
3
],
1.8
+
f
(
5
)
+
f
(
6
)),
Assignment
(
xi
[
4
],
5.7
+
f
(
6
)),
Assignment
(
xi
[
5
],
(
f
(
4
)
+
f
(
5
))
**
2
),
Assignment
(
xi
[
6
],
f
(
3
)
**
2
),
Assignment
(
xi
[
7
],
f
(
4
)),
Assignment
(
xi
[
8
],
13
),
Assignment
(
xi
[
9
],
0
),
]
assignments
=
[
Assignment
(
g
(
i
),
x
)
for
i
,
x
in
enumerate
(
xi
)]
ac
=
AssignmentCollection
(
assignments
,
subexpressions
=
subexpressions
)
ac_ins
=
insert_symbol_times_minus_one
(
ac
)
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
0
]})
ac_ins
=
insert_constant_multiples
(
ac
)
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
0
],
xi
[
2
]})
ac_ins
=
insert_constant_additions
(
ac
)
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
4
]})
ac_ins
=
insert_squares
(
ac
)
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
6
]})
ac_ins
=
insert_aliases
(
ac
)
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
7
]})
ac_ins
=
insert_zeros
(
ac
)
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
9
]})
ac_ins
=
insert_constants
(
ac
,
skip
=
{
xi
[
9
]})
assert
(
ac_ins
.
bound_symbols
&
xi_set
)
==
(
xi_set
-
{
xi
[
8
]})
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment