Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Tom Harke
pystencils
Commits
ef924b18
Commit
ef924b18
authored
Mar 31, 2018
by
Martin Bauer
Browse files
Code Cleanup
- assignment collection - sympyextensions
parent
c43672d2
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
__init__.py
View file @
ef924b18
...
...
@@ -5,6 +5,7 @@ from pystencils.kernelcreation import createKernel, createIndexedKernel
from
pystencils.display_utils
import
showCode
,
toDot
from
pystencils.assignment_collection
import
AssignmentCollection
from
pystencils.assignment
import
Assignment
from
pystencils.sympyextensions
import
SymbolCreator
__all__
=
[
'Field'
,
'FieldType'
,
'extractCommonSubexpressions'
,
'TypedSymbol'
,
...
...
@@ -12,4 +13,5 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'createKernel'
,
'createIndexedKernel'
,
'showCode'
,
'toDot'
,
'AssignmentCollection'
,
'Assignment'
]
'Assignment'
,
'SymbolCreator'
]
assignment_collection/assignment_collection.py
View file @
ef924b18
This diff is collapsed.
Click to expand it.
assignment_collection/simplifications.py
View file @
ef924b18
import
sympy
as
sp
from
typing
import
Callable
,
List
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils.sympyextensions
import
replaceA
dditive
from
pystencils.sympyextensions
import
subs_a
dditive
def
sympyCseOnEquationList
(
eqs
):
ec
=
AssignmentCollection
(
eqs
,
[])
return
sympyCSE
(
ec
).
allEquations
def
sympy_cse_on_assignment_list
(
assignments
:
List
[
Assignment
])
->
List
[
Assignment
]:
"""Extracts common subexpressions from a list of assignments."""
ec
=
AssignmentCollection
(
assignments
,
[])
return
sympy_cse
(
ec
).
all_assignments
def
sympyCSE
(
assignment_collection
):
"""
Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection
def
sympy_cse
(
ac
:
AssignmentCollection
)
->
AssignmentCollection
:
"""Searches for common subexpressions inside the equation collection.
Searches is done in both the existing subexpressions as well as the assignments themselves.
It uses the sympy subexpression detection to do this. Return a new equation collection
with the additional subexpressions found
"""
symbol
G
en
=
a
ssignment_collection
.
subexpression
S
ymbol
NameG
enerator
replacements
,
new
E
q
=
sp
.
cse
(
a
ssignment_collection
.
subexpressions
+
assignment_collection
.
main
A
ssignments
,
symbols
=
symbol
G
en
)
replacement
E
qs
=
[
Assignment
(
*
r
)
for
r
in
replacements
]
symbol
_g
en
=
a
c
.
subexpression
_s
ymbol
_g
enerator
replacements
,
new
_e
q
=
sp
.
cse
(
a
c
.
subexpressions
+
ac
.
main
_a
ssignments
,
symbols
=
symbol
_g
en
)
replacement
_e
qs
=
[
Assignment
(
*
r
)
for
r
in
replacements
]
modified
S
ubexpressions
=
new
E
q
[:
len
(
a
ssignment_collection
.
subexpressions
)]
modified
U
pdate
E
quations
=
new
E
q
[
len
(
a
ssignment_collection
.
subexpressions
):]
modified
_s
ubexpressions
=
new
_e
q
[:
len
(
a
c
.
subexpressions
)]
modified
_u
pdate
_e
quations
=
new
_e
q
[
len
(
a
c
.
subexpressions
):]
new
S
ubexpressions
=
replacement
E
qs
+
modified
S
ubexpressions
topologically
S
orted
P
airs
=
sp
.
cse_main
.
reps_toposort
([[
e
.
lhs
,
e
.
rhs
]
for
e
in
new
S
ubexpressions
])
new
S
ubexpressions
=
[
Assignment
(
a
[
0
],
a
[
1
])
for
a
in
topologically
S
orted
P
airs
]
new
_s
ubexpressions
=
replacement
_e
qs
+
modified
_s
ubexpressions
topologically
_s
orted
_p
airs
=
sp
.
cse_main
.
reps_toposort
([[
e
.
lhs
,
e
.
rhs
]
for
e
in
new
_s
ubexpressions
])
new
_s
ubexpressions
=
[
Assignment
(
a
[
0
],
a
[
1
])
for
a
in
topologically
_s
orted
_p
airs
]
return
a
ssignment_collection
.
copy
(
modified
U
pdate
E
quations
,
new
S
ubexpressions
)
return
a
c
.
copy
(
modified
_u
pdate
_e
quations
,
new
_s
ubexpressions
)
def
applyOnAllEquations
(
assignment_collection
,
operation
):
def
apply_to_all_assignments
(
assignment_collection
:
AssignmentCollection
,
operation
:
Callable
[[
sp
.
Expr
],
sp
.
Expr
])
->
AssignmentCollection
:
"""Applies sympy expand operation to all equations in collection"""
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
assignment_collection
.
main
A
ssignments
]
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
assignment_collection
.
main
_a
ssignments
]
return
assignment_collection
.
copy
(
result
)
def
applyOnAllSubexpressions
(
assignment_collection
,
operation
):
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
assignment_collection
.
subexpressions
]
return
assignment_collection
.
copy
(
assignment_collection
.
mainAssignments
,
result
)
def
apply_on_all_subexpressions
(
ac
:
AssignmentCollection
,
operation
:
Callable
[[
sp
.
Expr
],
sp
.
Expr
])
->
AssignmentCollection
:
result
=
[
Assignment
(
eq
.
lhs
,
operation
(
eq
.
rhs
))
for
eq
in
ac
.
subexpressions
]
return
ac
.
copy
(
ac
.
main_assignments
,
result
)
def
subexpression
S
ubstitution
InE
xisting
S
ubexpressions
(
assignment
_c
ollection
)
:
def
subexpression
_s
ubstitution
_in_e
xisting
_s
ubexpressions
(
a
c
:
A
ssignment
Collection
)
->
AssignmentC
ollection
:
"""Goes through the subexpressions list and replaces the term in the following subexpressions"""
result
=
[]
for
outerCtr
,
s
in
enumerate
(
a
ssignment_collection
.
subexpressions
):
new
R
hs
=
s
.
rhs
for
outerCtr
,
s
in
enumerate
(
a
c
.
subexpressions
):
new
_r
hs
=
s
.
rhs
for
innerCtr
in
range
(
outerCtr
):
sub
E
xpr
=
a
ssignment_collection
.
subexpressions
[
innerCtr
]
new
R
hs
=
replaceA
dditive
(
new
R
hs
,
sub
E
xpr
.
lhs
,
sub
E
xpr
.
rhs
,
required
M
atch
R
eplacement
=
1.0
)
new
R
hs
=
new
R
hs
.
subs
(
sub
E
xpr
.
rhs
,
sub
E
xpr
.
lhs
)
result
.
append
(
Assignment
(
s
.
lhs
,
new
R
hs
))
sub
_e
xpr
=
a
c
.
subexpressions
[
innerCtr
]
new
_r
hs
=
subs_a
dditive
(
new
_r
hs
,
sub
_e
xpr
.
lhs
,
sub
_e
xpr
.
rhs
,
required
_m
atch
_r
eplacement
=
1.0
)
new
_r
hs
=
new
_r
hs
.
subs
(
sub
_e
xpr
.
rhs
,
sub
_e
xpr
.
lhs
)
result
.
append
(
Assignment
(
s
.
lhs
,
new
_r
hs
))
return
a
ssignment_collection
.
copy
(
assignment_collection
.
main
A
ssignments
,
result
)
return
a
c
.
copy
(
ac
.
main
_a
ssignments
,
result
)
def
subexpression
S
ubstitution
In
main
A
ssignments
(
assignment
_c
ollection
)
:
"""Replaces already existing subexpressions in the equations of the assignment_collection"""
def
subexpression
_s
ubstitution
_in_
main
_a
ssignments
(
a
c
:
A
ssignment
Collection
)
->
AssignmentC
ollection
:
"""Replaces already existing subexpressions in the equations of the assignment_collection
.
"""
result
=
[]
for
s
in
a
ssignment_collection
.
main
A
ssignments
:
new
R
hs
=
s
.
rhs
for
subExpr
in
a
ssignment_collection
.
subexpressions
:
new
R
hs
=
replaceA
dditive
(
new
R
hs
,
subExpr
.
lhs
,
subExpr
.
rhs
,
required
M
atch
R
eplacement
=
1.0
)
result
.
append
(
Assignment
(
s
.
lhs
,
new
R
hs
))
return
a
ssignment_collection
.
copy
(
result
)
for
s
in
a
c
.
main
_a
ssignments
:
new
_r
hs
=
s
.
rhs
for
subExpr
in
a
c
.
subexpressions
:
new
_r
hs
=
subs_a
dditive
(
new
_r
hs
,
subExpr
.
lhs
,
subExpr
.
rhs
,
required
_m
atch
_r
eplacement
=
1.0
)
result
.
append
(
Assignment
(
s
.
lhs
,
new
_r
hs
))
return
a
c
.
copy
(
result
)
def
add
S
ubexpressions
ForD
ivisions
(
assignment
_c
ollection
):
def
add
_s
ubexpressions
_for_d
ivisions
(
a
c
:
A
ssignment
C
ollection
)
->
AssignmentCollection
:
"""Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`
\f
rac{1}{x}` is replaced, :math:`
\f
rac{1}{3}` is not replaced."""
For example :math:`
\f
rac{1}{x}` is replaced, :math:`
\f
rac{1}{3}` is not replaced.
"""
divisors
=
set
()
def
search
D
ivisors
(
term
):
def
search
_d
ivisors
(
term
):
if
term
.
func
==
sp
.
Pow
:
if
term
.
exp
.
is_integer
and
term
.
exp
.
is_number
and
term
.
exp
<
0
:
divisors
.
add
(
term
)
else
:
for
a
in
term
.
args
:
search
D
ivisors
(
a
)
search
_d
ivisors
(
a
)
for
eq
in
assignment
_collection
.
allEquation
s
:
search
D
ivisors
(
eq
.
rhs
)
for
eq
in
ac
.
all_
assignments
:
search
_d
ivisors
(
eq
.
rhs
)
new
S
ymbol
G
en
=
a
ssignment_collection
.
subexpression
S
ymbol
NameG
enerator
substitutions
=
{
divisor
:
newSymbol
for
newSymbol
,
divisor
in
zip
(
new
S
ymbol
G
en
,
divisors
)}
return
a
ssignment_collection
.
copyW
ith
S
ubstitutions
Applied
(
substitutions
,
True
)
new
_s
ymbol
_g
en
=
a
c
.
subexpression
_s
ymbol
_g
enerator
substitutions
=
{
divisor
:
newSymbol
for
newSymbol
,
divisor
in
zip
(
new
_s
ymbol
_g
en
,
divisors
)}
return
a
c
.
new_w
ith
_s
ubstitutions
(
substitutions
,
True
)
assignment_collection/simplificationstrategy.py
View file @
ef924b18
import
sympy
as
sp
from
collections
import
namedtuple
from
typing
import
Callable
,
Any
,
Optional
,
Sequence
from
pystencils.assignment_collection.assignment_collection
import
AssignmentCollection
class
SimplificationStrategy
(
object
):
"""
A simplification strategy is an ordered collection of simplification rules.
"""
A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified
equation collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks.
...
...
@@ -13,10 +15,11 @@ class SimplificationStrategy(object):
def
__init__
(
self
):
self
.
_rules
=
[]
def
add
(
self
,
rule
):
"""
Adds the given simplification rule to the end of the collection.
:param rule: function that taking one equation collection and returning a (simplified) equation collection
def
add
(
self
,
rule
:
Callable
[[
AssignmentCollection
],
AssignmentCollection
])
->
None
:
"""Adds the given simplification rule to the end of the collection.
Args:
rule: function that rewrites/simplifies an assignment collection
"""
self
.
_rules
.
append
(
rule
)
...
...
@@ -24,19 +27,20 @@ class SimplificationStrategy(object):
def
rules
(
self
):
return
self
.
_rules
def
apply
(
self
,
updateRule
)
:
"""
Applies all simplification
rules
t
o the given
equation
collection"""
def
apply
(
self
,
assignment_collection
:
AssignmentCollection
)
->
AssignmentCollection
:
"""
Runs all
rules o
n
the given
assignment
collection
.
"""
for
t
in
self
.
_rules
:
updateRule
=
t
(
updateRule
)
return
updateRule
assignment_collection
=
t
(
assignment_collection
)
return
assignment_collection
def
__call__
(
self
,
assignment_collection
)
:
def
__call__
(
self
,
assignment_collection
:
AssignmentCollection
)
->
AssignmentCollection
:
"""Same as apply"""
return
self
.
apply
(
assignment_collection
)
def
createSimplificationReport
(
self
,
assignment_collection
):
"""
Returns a simplification report containing the number of operations at each simplification stage, together
def
create_simplification_report
(
self
,
assignment_collection
:
AssignmentCollection
)
->
Any
:
"""Creates a report to be displayed as HTML in a Jupyter notebook.
The simplification report contains the number of operations at each simplification stage together
with the run-time the simplification took.
"""
...
...
@@ -60,70 +64,83 @@ class SimplificationStrategy(object):
return
result
def
_repr_html_
(
self
):
htmlTable
=
'<table style="border:none">'
htmlTable
+=
"<tr><th>Name</th><th>Runtime</th><th>Adds</th><th>Muls</th><th>Divs</th><th>Total</th></tr>"
html_table
=
'<table style="border:none">'
html_table
+=
"<tr><th>Name</th>"
\
"<th>Runtime</th>"
\
"<th>Adds</th>"
\
"<th>Muls</th>"
\
"<th>Divs</th>"
\
"<th>Total</th></tr>"
line
=
"<tr><td>{simplificationName}</td>"
\
"<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>"
for
e
in
self
.
elements
:
htmlTable
+=
line
.
format
(
**
e
.
_asdict
())
htmlTable
+=
"</table>"
return
htmlTable
# noinspection PyProtectedMember
html_table
+=
line
.
format
(
**
e
.
_asdict
())
html_table
+=
"</table>"
return
html_table
import
timeit
report
=
Report
()
op
=
assignment_collection
.
operation
C
ount
op
=
assignment_collection
.
operation
_c
ount
total
=
op
[
'adds'
]
+
op
[
'muls'
]
+
op
[
'divs'
]
report
.
add
(
ReportElement
(
"OriginalTerm"
,
'-'
,
op
[
'adds'
],
op
[
'muls'
],
op
[
'divs'
],
total
))
for
t
in
self
.
_rules
:
start
T
ime
=
timeit
.
default_timer
()
start
_t
ime
=
timeit
.
default_timer
()
assignment_collection
=
t
(
assignment_collection
)
end
T
ime
=
timeit
.
default_timer
()
op
=
assignment_collection
.
operation
C
ount
time
S
tr
=
"%.2f ms"
%
((
end
T
ime
-
start
T
ime
)
*
1000
,)
end
_t
ime
=
timeit
.
default_timer
()
op
=
assignment_collection
.
operation
_c
ount
time
_s
tr
=
"%.2f ms"
%
((
end
_t
ime
-
start
_t
ime
)
*
1000
,)
total
=
op
[
'adds'
]
+
op
[
'muls'
]
+
op
[
'divs'
]
report
.
add
(
ReportElement
(
t
.
__name__
,
time
S
tr
,
op
[
'adds'
],
op
[
'muls'
],
op
[
'divs'
],
total
))
report
.
add
(
ReportElement
(
t
.
__name__
,
time
_s
tr
,
op
[
'adds'
],
op
[
'muls'
],
op
[
'divs'
],
total
))
return
report
def
showIntermediateResults
(
self
,
assignment_collection
,
symbols
=
None
):
def
show_intermediate_results
(
self
,
assignment_collection
:
AssignmentCollection
,
symbols
:
Optional
[
Sequence
[
sp
.
Symbol
]]
=
None
)
->
Any
:
"""Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.
Args:
assignment_collection: the collection to apply the rules to
symbols: if not None, only the assignments are shown that have one of these symbols as left hand side
"""
class
IntermediateResults
:
def
__init__
(
self
,
strategy
,
eqColl
,
resSym
s
):
def
__init__
(
self
,
strategy
,
collection
,
restrict_symbol
s
):
self
.
strategy
=
strategy
self
.
assignment_collection
=
eqColl
self
.
restrict
S
ymbols
=
res
Sym
s
self
.
assignment_collection
=
collection
self
.
restrict
_s
ymbols
=
res
trict_symbol
s
def
__str__
(
self
):
def
print
EqC
ollection
(
title
,
eqColl
):
def
print
_assignment_c
ollection
(
title
,
c
):
text
=
title
if
self
.
restrict
S
ymbols
:
text
+=
"
\n
"
.
join
([
str
(
e
)
for
e
in
eqColl
.
get
(
self
.
restrict
S
ymbols
)])
if
self
.
restrict
_s
ymbols
:
text
+=
"
\n
"
.
join
([
str
(
e
)
for
e
in
c
.
get
(
self
.
restrict
_s
ymbols
)])
else
:
text
+=
(
" "
*
3
+
(
" "
*
3
).
join
(
str
(
eqColl
).
splitlines
(
True
)))
text
+=
(
" "
*
3
+
(
" "
*
3
).
join
(
str
(
c
).
splitlines
(
True
)))
return
text
result
=
print
EqC
ollection
(
"Initial Version"
,
self
.
assignment_collection
)
eqColl
=
self
.
assignment_collection
result
=
print
_assignment_c
ollection
(
"Initial Version"
,
self
.
assignment_collection
)
collection
=
self
.
assignment_collection
for
rule
in
self
.
strategy
.
rules
:
eqColl
=
rule
(
eqColl
)
result
+=
print
EqC
ollection
(
rule
.
__name__
,
eqColl
)
collection
=
rule
(
collection
)
result
+=
print
_assignment_c
ollection
(
rule
.
__name__
,
collection
)
return
result
def
_repr_html_
(
self
):
def
print
EqC
ollection
(
title
,
eqColl
):
def
print
_assignment_c
ollection
(
title
,
c
):
text
=
'<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">'
%
(
title
,
)
if
self
.
restrict
S
ymbols
:
text
+=
"
\n
"
.
join
([
"$$"
+
sp
.
latex
(
e
)
+
'$$'
for
e
in
eqColl
.
get
(
self
.
restrict
S
ymbols
)])
if
self
.
restrict
_s
ymbols
:
text
+=
"
\n
"
.
join
([
"$$"
+
sp
.
latex
(
e
)
+
'$$'
for
e
in
c
.
get
(
self
.
restrict
_s
ymbols
)])
else
:
text
+=
eqColl
.
_repr_html_
()
# noinspection PyProtectedMember
text
+=
c
.
_repr_html_
()
text
+=
"</div>"
return
text
result
=
print
EqC
ollection
(
"Initial Version"
,
self
.
assignment_collection
)
eqColl
=
self
.
assignment_collection
result
=
print
_assignment_c
ollection
(
"Initial Version"
,
self
.
assignment_collection
)
collection
=
self
.
assignment_collection
for
rule
in
self
.
strategy
.
rules
:
eqColl
=
rule
(
eqColl
)
result
+=
print
EqC
ollection
(
rule
.
__name__
,
eqColl
)
collection
=
rule
(
collection
)
result
+=
print
_assignment_c
ollection
(
rule
.
__name__
,
collection
)
return
result
return
IntermediateResults
(
self
,
assignment_collection
,
symbols
)
...
...
astnodes.py
View file @
ef924b18
...
...
@@ -2,7 +2,7 @@ import sympy as sp
from
sympy.tensor
import
IndexedBase
from
pystencils.field
import
Field
from
pystencils.data_types
import
TypedSymbol
,
createType
,
castFunc
from
pystencils.sympyextensions
import
fast
S
ubs
from
pystencils.sympyextensions
import
fast
_s
ubs
class
Node
(
object
):
...
...
@@ -275,11 +275,11 @@ class Block(Node):
@
property
def
undefinedSymbols
(
self
):
result
=
set
()
defined
S
ymbols
=
set
()
defined
_s
ymbols
=
set
()
for
a
in
self
.
args
:
result
.
update
(
a
.
undefinedSymbols
)
defined
S
ymbols
.
update
(
a
.
symbolsDefined
)
return
result
-
defined
S
ymbols
defined
_s
ymbols
.
update
(
a
.
symbolsDefined
)
return
result
-
defined
_s
ymbols
def
__str__
(
self
):
return
"Block "
+
''
.
join
(
'{!s}
\n
'
.
format
(
node
)
for
node
in
self
.
_nodes
)
...
...
@@ -426,8 +426,8 @@ class SympyAssignment(Node):
self
.
_isDeclaration
=
False
def
subs
(
self
,
*
args
,
**
kwargs
):
self
.
lhs
=
fast
S
ubs
(
self
.
lhs
,
*
args
,
**
kwargs
)
self
.
rhs
=
fast
S
ubs
(
self
.
rhs
,
*
args
,
**
kwargs
)
self
.
lhs
=
fast
_s
ubs
(
self
.
lhs
,
*
args
,
**
kwargs
)
self
.
rhs
=
fast
_s
ubs
(
self
.
rhs
,
*
args
,
**
kwargs
)
@
property
def
args
(
self
):
...
...
@@ -494,11 +494,11 @@ class ResolvedFieldAccess(sp.Indexed):
self
.
args
[
1
].
subs
(
old
,
new
),
self
.
field
,
self
.
offsets
,
self
.
idxCoordinateValues
)
def
fast
S
ubs
(
self
,
subs
Dict
):
if
self
in
subs
Dict
:
return
subs
Dict
[
self
]
return
ResolvedFieldAccess
(
self
.
args
[
0
].
subs
(
subs
Dict
),
self
.
args
[
1
].
subs
(
subs
Dict
),
def
fast
_s
ubs
(
self
,
subs
titutions
):
if
self
in
subs
titutions
:
return
subs
titutions
[
self
]
return
ResolvedFieldAccess
(
self
.
args
[
0
].
subs
(
subs
titutions
),
self
.
args
[
1
].
subs
(
subs
titutions
),
self
.
field
,
self
.
offsets
,
self
.
idxCoordinateValues
)
def
_hashable_content
(
self
):
...
...
derivative.py
View file @
ef924b18
import
sympy
as
sp
from
collections
import
namedtuple
,
defaultdict
from
pystencils.sympyextensions
import
normalize
P
roduct
,
prod
from
pystencils.sympyextensions
import
normalize
_p
roduct
,
prod
def
defaultDiffSortKey
(
d
):
...
...
@@ -57,7 +57,7 @@ class Diff(sp.Expr):
if
self
.
arg
.
func
!=
sp
.
Mul
:
constant
,
variable
=
1
,
self
.
arg
else
:
for
factor
in
normalize
P
roduct
(
self
.
arg
):
for
factor
in
normalize
_p
roduct
(
self
.
arg
):
if
factor
in
functions
or
isinstance
(
factor
,
Diff
):
variable
*=
factor
else
:
...
...
@@ -150,7 +150,7 @@ class DiffOperator(sp.Expr):
i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t)
"""
def
handleMul
(
mul
):
args
=
normalize
P
roduct
(
mul
)
args
=
normalize
_p
roduct
(
mul
)
diffs
=
[
a
for
a
in
args
if
isinstance
(
a
,
DiffOperator
)]
if
len
(
diffs
)
==
0
:
return
mul
*
argument
if
applyToConstants
else
mul
...
...
@@ -254,7 +254,7 @@ def fullDiffExpand(expr, functions=None, constants=None):
for
term
in
diffInner
.
args
if
diffInner
.
func
==
sp
.
Add
else
[
diffInner
]:
independentTerms
=
1
dependentTerms
=
[]
for
factor
in
normalize
P
roduct
(
term
):
for
factor
in
normalize
_p
roduct
(
term
):
if
factor
in
functions
or
isinstance
(
factor
,
Diff
):
dependentTerms
.
append
(
factor
)
else
:
...
...
@@ -310,7 +310,7 @@ def expandUsingProductRule(expr):
if
arg
.
func
not
in
(
sp
.
Mul
,
sp
.
Pow
):
return
Diff
(
arg
,
target
=
expr
.
target
,
superscript
=
expr
.
superscript
)
else
:
prodList
=
normalize
P
roduct
(
arg
)
prodList
=
normalize
_p
roduct
(
arg
)
result
=
0
for
i
in
range
(
len
(
prodList
)):
preFactor
=
prod
(
prodList
[
j
]
for
j
in
range
(
len
(
prodList
))
if
i
!=
j
)
...
...
@@ -347,7 +347,7 @@ def combineUsingProductRule(expr):
if
isinstance
(
term
,
Diff
):
diffDict
[
DiffInfo
(
term
.
target
,
term
.
superscript
)].
append
(
DiffSplit
(
1
,
term
.
arg
))
else
:
mulArgs
=
normalize
P
roduct
(
term
)
mulArgs
=
normalize
_p
roduct
(
term
)
diffs
=
[
d
for
d
in
mulArgs
if
isinstance
(
d
,
Diff
)]
factor
=
prod
(
d
for
d
in
mulArgs
if
not
isinstance
(
d
,
Diff
))
if
len
(
diffs
)
==
0
:
...
...
field.py
View file @
ef924b18
...
...
@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
from
pystencils.assignment
import
Assignment
from
pystencils.alignedarray
import
aligned_empty
from
pystencils.data_types
import
TypedSymbol
,
createType
,
createCompositeTypeFromString
,
StructType
from
pystencils.sympyextensions
import
is
I
nteger
S
equence
from
pystencils.sympyextensions
import
is
_i
nteger
_s
equence
class
FieldType
(
Enum
):
...
...
@@ -221,7 +221,7 @@ class Field(object):
@
property
def
hasFixedShape
(
self
):
return
is
I
nteger
S
equence
(
self
.
shape
)
return
is
_i
nteger
_s
equence
(
self
.
shape
)
@
property
def
indexShape
(
self
):
...
...
@@ -229,7 +229,7 @@ class Field(object):
@
property
def
hasFixedIndexShape
(
self
):
return
is
I
nteger
S
equence
(
self
.
indexShape
)
return
is
_i
nteger
_s
equence
(
self
.
indexShape
)
@
property
def
spatialStrides
(
self
):
...
...
finitedifferences.py
View file @
ef924b18
...
...
@@ -3,7 +3,7 @@ import sympy as sp
from
pystencils.assignment_collection
import
AssignmentCollection
from
pystencils.field
import
Field
from
pystencils.
transformat
ions
import
fast
S
ubs
from
pystencils.
sympyextens
ions
import
fast
_s
ubs
from
pystencils.derivative
import
Diff
...
...
@@ -103,7 +103,7 @@ def discretizeStaggered(term, symbolsToFieldDict, coordinate, coordinateOffset,
neighborGrad
=
(
field
[
up
+
offset
](
i
)
-
field
[
down
+
offset
](
i
))
/
(
2
*
dx
)
substitutions
[
grad
(
s
)[
d
]]
=
(
centerGrad
+
neighborGrad
)
/
2
return
fast
S
ubs
(
term
,
substitutions
)
return
fast
_s
ubs
(
term
,
substitutions
)
def
discretizeDivergence
(
vectorTerm
,
symbolsToFieldDict
,
dx
):
...
...
@@ -356,7 +356,7 @@ class Discretization2ndOrder:
elif
isinstance
(
expr
,
sp
.
Matrix
):
return
expr
.
applyfunc
(
self
.
__call__
)
elif
isinstance
(
expr
,
AssignmentCollection
):
return
expr
.
copy
(
main
A
ssignments
=
[
e
for
e
in
expr
.
main
A
ssignments
],
return
expr
.
copy
(
main
_a
ssignments
=
[
e
for
e
in
expr
.
main
_a
ssignments
],
subexpressions
=
[
e
for
e
in
expr
.
subexpressions
])
transientTerms
=
expr
.
atoms
(
Transient
)
...
...
kerncraft_coupling/kerncraft_interface.py
View file @
ef924b18
...
...
@@ -12,7 +12,7 @@ from kerncraft.iaca import iaca_analyse_instrumented_binary, iaca_instrumentatio
from
pystencils.kerncraft_coupling.generate_benchmark
import
generateBenchmark
from
pystencils.astnodes
import
LoopOverCoordinate
,
SympyAssignment
,
ResolvedFieldAccess
from
pystencils.field
import
getLayoutFromStrides
from
pystencils.sympyextensions
import
count
NumberOfO
perations
InA
st
from
pystencils.sympyextensions
import
count
_o
perations
_in_a
st
from
pystencils.utils
import
DotDict
...
...
@@ -78,7 +78,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel):
self
.
datatype
=
list
(
self
.
variables
.
values
())[
0
][
0
]
# flops
operationCount
=
count
NumberOfO
perations
InA
st
(
innerLoop
)
operationCount
=
count
_o
perations
_in_a
st
(
innerLoop
)
self
.
_flops
=
{
'+'
:
operationCount
[
'adds'
],
'*'
:
operationCount
[
'muls'
],
...
...
kernelcreation.py
View file @
ef924b18
...
...
@@ -33,9 +33,9 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None
# ---- Normalizing parameters
splitGroups
=
()
if
isinstance
(
equations
,
AssignmentCollection
):
if
'splitGroups'
in
equations
.
simplification
H
ints
:
splitGroups
=
equations
.
simplification
H
ints
[
'splitGroups'
]
equations
=
equations
.
all
Equation
s
if
'splitGroups'
in
equations
.
simplification
_h
ints
:
splitGroups
=
equations
.
simplification
_h
ints
[
'splitGroups'
]
equations
=
equations
.
all
_assignment
s
# ---- Creating ast
if
target
==
'cpu'
:
...
...
@@ -84,7 +84,7 @@ def createIndexedKernel(equations, indexFields, target='cpu', dataType="double",
"""
if
isinstance
(
equations
,
AssignmentCollection
):
equations
=
equations
.
all
Equation
s
equations
=
equations
.
all
_assignment
s
if
target
==
'cpu'
:
from
pystencils.cpu
import
createIndexedKernel
from
pystencils.cpu
import
addOpenMP
...
...
sympyextensions.py
View file @
ef924b18
This diff is collapsed.
Click to expand it.
transformations/transformations.py
View file @
ef924b18
...
...
@@ -21,22 +21,6 @@ def filteredTreeIteration(node, nodeType):
yield
from
filteredTreeIteration
(
arg
,
nodeType
)
def
fastSubs
(
term
,
subsDict
):
"""Similar to sympy subs function.
This version is much faster for big substitution dictionaries than sympy version"""
if
type
(
term
)
is
sp
.
Matrix
:
return
term
.
copy
().
applyfunc
(
functools
.
partial
(
fastSubs
,
subsDict
=
subsDict
))
def
visit
(
expr
):
if
expr
in
subsDict
:
return
subsDict
[
expr
]
if
not
hasattr
(
expr
,
'args'
):
return
expr
paramList
=
[
visit
(
a
)
for
a
in
expr
.
args
]
return
expr
if
not
paramList
else
expr
.
func
(
*
paramList
)
return
visit
(
term
)
def
getCommonShape
(
fieldSet
):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
ValueError is raised"""
...
...
vectorization.py
View file @
ef924b18
import
sympy
as
sp
import
warnings
from
pystencils.sympyextensions
import
fast
S
ubs
from
pystencils.sympyextensions
import
fast
_s
ubs
from
pystencils.transformations
import
filteredTreeIteration
from
pystencils.data_types
import
TypedSymbol
,
VectorType
,
BasicType
,
getTypeOfExpression
,
castFunc
,
collateTypes
,
\
PointerType
...
...
@@ -97,7 +97,7 @@ def insertVectorCasts(astNode):
substitutionDict
=
{}
for
asmt
in
filteredTreeIteration
(
astNode
,
ast
.
SympyAssignment
):
subsExpr
=
fast
S
ubs
(
asmt
.
rhs
,
substitutionDict
,
skip
=
lambda
e
:
isinstance
(
e
,
ast
.
ResolvedFieldAccess
))
subsExpr
=
fast
_s
ubs
(
asmt
.
rhs
,
substitutionDict
,
skip
=
lambda
e
:
isinstance
(
e
,
ast
.
ResolvedFieldAccess
))
asmt
.
rhs
=
visitExpr
(
subsExpr
)
rhsType
=
getTypeOfExpression
(
asmt
.
rhs
)
if
isinstance
(
asmt
.
lhs
,
TypedSymbol
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment