Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Bindgen
pystencils
Commits
1eacfdad
Commit
1eacfdad
authored
Dec 02, 2017
by
Martin Bauer
Browse files
New transformations for staggered field traversal
- loop cutting - simplification of conditionals inside loop
parent
74b69826
Changes
5
Hide whitespace changes
Inline
Side-by-side
astnodes.py
View file @
1eacfdad
...
...
@@ -20,6 +20,13 @@ class ResolvedFieldAccess(sp.Indexed):
self
.
args
[
1
].
subs
(
old
,
new
),
self
.
field
,
self
.
offsets
,
self
.
idxCoordinateValues
)
def
fastSubs
(
self
,
subsDict
):
if
self
in
subsDict
:
return
subsDict
[
self
]
return
ResolvedFieldAccess
(
self
.
args
[
0
].
subs
(
subsDict
),
self
.
args
[
1
].
subs
(
subsDict
),
self
.
field
,
self
.
offsets
,
self
.
idxCoordinateValues
)
def
_hashable_content
(
self
):
superClassContents
=
super
(
ResolvedFieldAccess
,
self
).
_hashable_content
()
return
superClassContents
+
tuple
(
self
.
offsets
)
+
(
repr
(
self
.
idxCoordinateValues
),
hash
(
self
.
field
))
...
...
@@ -89,8 +96,23 @@ class Conditional(Node):
"""
assert
conditionExpr
.
is_Boolean
or
conditionExpr
.
is_Relational
self
.
conditionExpr
=
conditionExpr
self
.
trueBlock
=
trueBlock
self
.
falseBlock
=
falseBlock
def
handleChild
(
c
):
if
c
is
None
:
return
None
if
not
isinstance
(
c
,
Block
):
c
=
Block
([
c
])
c
.
parent
=
self
return
c
self
.
trueBlock
=
handleChild
(
trueBlock
)
self
.
falseBlock
=
handleChild
(
falseBlock
)
def
subs
(
self
,
*
args
,
**
kwargs
):
self
.
trueBlock
.
subs
(
*
args
,
**
kwargs
)
if
self
.
falseBlock
:
self
.
falseBlock
.
subs
(
*
args
,
**
kwargs
)
self
.
conditionExpr
=
self
.
conditionExpr
.
subs
(
*
args
,
**
kwargs
)
@
property
def
args
(
self
):
...
...
@@ -107,7 +129,7 @@ class Conditional(Node):
def
undefinedSymbols
(
self
):
result
=
self
.
trueBlock
.
undefinedSymbols
if
self
.
falseBlock
:
result
=
result
.
update
(
self
.
falseBlock
.
undefinedSymbols
)
result
.
update
(
self
.
falseBlock
.
undefinedSymbols
)
result
.
update
(
self
.
conditionExpr
.
atoms
(
sp
.
Symbol
))
return
result
...
...
@@ -243,11 +265,21 @@ class Block(Node):
def
insertBefore
(
self
,
newNode
,
insertBefore
):
newNode
.
parent
=
self
idx
=
self
.
_nodes
.
index
(
insertBefore
)
# move all assignment (definitions to the top)
if
isinstance
(
newNode
,
SympyAssignment
)
and
newNode
.
isDeclaration
:
while
idx
>
0
and
not
(
isinstance
(
self
.
_nodes
[
idx
-
1
],
SympyAssignment
)
and
self
.
_nodes
[
idx
-
1
].
isDeclaration
):
idx
-=
1
self
.
_nodes
.
insert
(
idx
,
newNode
)
def
append
(
self
,
node
):
node
.
parent
=
self
self
.
_nodes
.
append
(
node
)
if
isinstance
(
node
,
list
)
or
isinstance
(
node
,
tuple
):
for
n
in
node
:
n
.
parent
=
self
self
.
_nodes
.
append
(
n
)
else
:
node
.
parent
=
self
self
.
_nodes
.
append
(
node
)
def
takeChildNodes
(
self
):
tmp
=
self
.
_nodes
...
...
@@ -339,7 +371,6 @@ class LoopOverCoordinate(Node):
elif
child
==
self
.
stop
:
self
.
stop
=
replacement
@
property
def
symbolsDefined
(
self
):
return
set
([
self
.
loopCounterSymbol
])
...
...
@@ -389,14 +420,14 @@ class LoopOverCoordinate(Node):
def
__str__
(
self
):
return
'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})
\n
{!s}'
.
format
(
self
.
loopCounterName
,
self
.
start
,
self
.
loopCounterName
,
self
.
stop
,
self
.
loopCounterName
,
self
.
step
,
(
"
\t
"
+
"
\t
"
.
join
(
str
(
self
.
body
).
splitlines
(
True
))))
self
.
loopCounterName
,
self
.
stop
,
self
.
loopCounterName
,
self
.
step
,
(
"
\t
"
+
"
\t
"
.
join
(
str
(
self
.
body
).
splitlines
(
True
))))
def
__repr__
(
self
):
return
'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'
.
format
(
self
.
loopCounterName
,
self
.
start
,
self
.
loopCounterName
,
self
.
stop
,
self
.
loopCounterName
,
self
.
step
)
self
.
loopCounterName
,
self
.
stop
,
self
.
loopCounterName
,
self
.
step
)
class
SympyAssignment
(
Node
):
...
...
backends/cbackend.py
View file @
1eacfdad
...
...
@@ -161,7 +161,7 @@ class CBackend(object):
def
_print_Conditional
(
self
,
node
):
conditionExpr
=
self
.
sympyPrinter
.
doprint
(
node
.
conditionExpr
)
trueBlock
=
self
.
_print_Block
(
node
.
trueBlock
)
result
=
"if (%s)
\n
%s "
%
(
conditionExpr
,
trueBlock
)
result
=
"if (%s)
\n
%s "
%
(
conditionExpr
,
trueBlock
)
if
node
.
falseBlock
:
falseBlock
=
self
.
_print_Block
(
node
.
falseBlock
)
result
+=
"else "
+
falseBlock
...
...
cpu/cpujit.py
View file @
1eacfdad
"""
r
"""
*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.
...
...
@@ -58,7 +58,6 @@ compiled into the shared library. Then, the same script can be run from the comp
- **'objectCache'**: path to a folder where intermediate files are stored
- **'clearCacheOnStart'**: when true the cache is cleared on each start of a *pystencils* script
- **'sharedLibrary'**: path to a shared library file, which is created if `readFromSharedLibrary=false`
"""
from
__future__
import
print_function
import
os
...
...
@@ -197,7 +196,8 @@ def readConfig():
configPath
,
configExists
=
getConfigurationFilePath
()
config
=
defaultConfig
.
copy
()
if
configExists
:
loadedConfig
=
json
.
load
(
open
(
configPath
,
'r'
))
with
open
(
configPath
,
'r'
)
as
jsonConfigFile
:
loadedConfig
=
json
.
load
(
jsonConfigFile
)
config
=
_recursiveDictUpdate
(
config
,
loadedConfig
)
else
:
createFolder
(
configPath
,
True
)
...
...
sympyextensions.py
View file @
1eacfdad
...
...
@@ -86,6 +86,8 @@ def fastSubs(term, subsDict, skip=None):
def
visit
(
expr
):
if
skip
and
skip
(
expr
):
return
expr
if
hasattr
(
expr
,
"fastSubs"
):
return
expr
.
fastSubs
(
subsDict
)
if
expr
in
subsDict
:
return
subsDict
[
expr
]
if
not
hasattr
(
expr
,
'args'
):
...
...
transformations/transformations.py
View file @
1eacfdad
...
...
@@ -390,7 +390,12 @@ def moveConstantsBeforeLoop(astNode):
if
isinstance
(
element
,
ast
.
Block
):
lastBlock
=
element
lastBlockChild
=
prevElement
if
node
.
undefinedSymbols
.
intersection
(
element
.
symbolsDefined
):
if
isinstance
(
element
,
ast
.
Conditional
):
criticalSymbols
=
element
.
conditionExpr
.
atoms
(
sp
.
Symbol
)
else
:
criticalSymbols
=
element
.
symbolsDefined
if
node
.
undefinedSymbols
.
intersection
(
criticalSymbols
):
break
prevElement
=
element
element
=
element
.
parent
...
...
@@ -496,6 +501,120 @@ def splitInnerLoop(astNode, symbolGroups):
outerLoop
.
parent
.
append
(
ast
.
TemporaryMemoryFree
(
tmpArrayPointer
))
def
cutLoop
(
loopNode
,
cuttingPoints
):
"""Cuts loop at given cutting points, that means one loop is transformed into len(cuttingPoints)+1 new loops
that range from oldBegin to cuttingPoint[1], ..., cuttingPoint[-1] to oldEnd"""
if
loopNode
.
step
!=
1
:
raise
NotImplementedError
(
"Can only split loops that have a step of 1"
)
newLoops
=
[]
newStart
=
loopNode
.
start
cuttingPoints
=
list
(
cuttingPoints
)
+
[
loopNode
.
stop
]
for
newEnd
in
cuttingPoints
:
if
newEnd
-
newStart
==
1
:
newBody
=
deepcopy
(
loopNode
.
body
)
newBody
.
subs
({
loopNode
.
loopCounterSymbol
:
newStart
})
newLoops
.
append
(
newBody
)
else
:
newLoop
=
ast
.
LoopOverCoordinate
(
deepcopy
(
loopNode
.
body
),
loopNode
.
coordinateToLoopOver
,
newStart
,
newEnd
,
loopNode
.
step
)
newLoops
.
append
(
newLoop
)
newStart
=
newEnd
loopNode
.
parent
.
replace
(
loopNode
,
newLoops
)
def
isConditionNecessary
(
condition
,
preCondition
,
symbol
):
"""
Determines if a logical condition of a single variable is already contained in a stronger preCondition
so if from preCondition follows that condition is always true, then this condition is not necessary
:param condition: sympy relational of one variable
:param preCondition: logical expression that is known to be true
:param symbol: the single symbol of interest
:return: returns not (preCondition => condition) where "=>" is logical implication
"""
from
sympy.solvers.inequalities
import
reduce_rational_inequalities
from
sympy.logic.boolalg
import
to_dnf
def
toDnfList
(
expr
):
result
=
to_dnf
(
expr
)
if
isinstance
(
result
,
sp
.
Or
):
return
[
orTerm
.
args
for
orTerm
in
result
.
args
]
elif
isinstance
(
result
,
sp
.
And
):
return
[
result
.
args
]
else
:
return
result
t1
=
reduce_rational_inequalities
(
toDnfList
(
sp
.
And
(
condition
,
preCondition
)),
symbol
)
t2
=
reduce_rational_inequalities
(
toDnfList
(
preCondition
),
symbol
)
return
t1
!=
t2
def
simplifyBooleanExpression
(
expr
,
singleVariableRanges
):
"""Simplification of boolean expression using known ranges of variables
The singleVariableRanges parameter is a dict mapping a variable name to a sympy logical expression that
contains only this variable and defines a range for it. For example with a being a symbol
{ a: sp.And(a >=0, a < 10) }
"""
from
sympy.core.relational
import
Relational
from
sympy.logic.boolalg
import
to_dnf
expr
=
to_dnf
(
expr
)
def
visit
(
e
):
if
isinstance
(
e
,
Relational
):
symbols
=
e
.
atoms
(
sp
.
Symbol
)
if
len
(
symbols
)
==
1
:
symbol
=
symbols
.
pop
()
if
symbol
in
singleVariableRanges
:
if
not
isConditionNecessary
(
e
,
singleVariableRanges
[
symbol
],
symbol
):
return
sp
.
true
return
e
else
:
newArgs
=
[
visit
(
a
)
for
a
in
e
.
args
]
return
e
.
func
(
*
newArgs
)
if
newArgs
else
e
return
visit
(
expr
)
def
simplifyConditionals
(
node
,
loopConditionals
=
{}):
"""Simplifies/Removes conditions inside loops that depend on the loop counter."""
if
isinstance
(
node
,
ast
.
LoopOverCoordinate
):
ctrSym
=
node
.
loopCounterSymbol
loopConditionals
[
ctrSym
]
=
sp
.
And
(
ctrSym
>=
node
.
start
,
ctrSym
<
node
.
stop
)
simplifyConditionals
(
node
.
body
)
del
loopConditionals
[
ctrSym
]
elif
isinstance
(
node
,
ast
.
Conditional
):
node
.
conditionExpr
=
simplifyBooleanExpression
(
node
.
conditionExpr
,
loopConditionals
)
simplifyConditionals
(
node
.
trueBlock
)
if
node
.
falseBlock
:
simplifyConditionals
(
node
.
falseBlock
)
if
node
.
conditionExpr
==
sp
.
true
:
node
.
parent
.
replace
(
node
,
[
node
.
trueBlock
])
if
node
.
conditionExpr
==
sp
.
false
:
node
.
parent
.
replace
(
node
,
[
node
.
falseBlock
]
if
node
.
falseBlock
else
[])
elif
isinstance
(
node
,
ast
.
Block
):
for
a
in
list
(
node
.
args
):
simplifyConditionals
(
a
)
elif
isinstance
(
node
,
ast
.
SympyAssignment
):
return
node
else
:
raise
ValueError
(
"Can not handle node"
,
type
(
node
))
def
cleanupBlocks
(
node
):
"""Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child """
if
isinstance
(
node
,
ast
.
SympyAssignment
):
return
elif
isinstance
(
node
,
ast
.
Block
):
for
a
in
list
(
node
.
args
):
cleanupBlocks
(
a
)
if
len
(
node
.
args
)
<=
1
and
isinstance
(
node
.
parent
,
ast
.
Block
):
node
.
parent
.
replace
(
node
,
node
.
args
)
return
else
:
for
a
in
node
.
args
:
cleanupBlocks
(
a
)
def
symbolNameToVariableName
(
symbolName
):
"""Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names"""
return
symbolName
.
replace
(
"^"
,
"_"
)
...
...
@@ -546,17 +665,23 @@ def typeAllEquations(eqs, typeForSymbol):
else
:
assert
False
,
"Expected a symbol as left-hand-side"
typedEquations
=
[]
for
eq
in
eqs
:
if
isinstance
(
eq
,
sp
.
Eq
)
or
isinstance
(
eq
,
ast
.
SympyAssignment
):
newLhs
=
processLhs
(
eq
.
lhs
)
newRhs
=
processRhs
(
eq
.
rhs
)
typedEquations
.
append
(
ast
.
SympyAssignment
(
newLhs
,
newRhs
))
def
visit
(
object
):
if
isinstance
(
object
,
list
)
or
isinstance
(
object
,
tuple
):
return
[
visit
(
e
)
for
e
in
object
]
if
isinstance
(
object
,
sp
.
Eq
)
or
isinstance
(
object
,
ast
.
SympyAssignment
):
newLhs
=
processLhs
(
object
.
lhs
)
newRhs
=
processRhs
(
object
.
rhs
)
return
ast
.
SympyAssignment
(
newLhs
,
newRhs
)
elif
isinstance
(
object
,
ast
.
Conditional
):
falseBlock
=
None
if
object
.
falseBlock
is
None
else
visit
(
object
.
falseBlock
)
return
ast
.
Conditional
(
processRhs
(
object
.
conditionExpr
),
trueBlock
=
visit
(
object
.
trueBlock
),
falseBlock
=
falseBlock
)
elif
isinstance
(
object
,
ast
.
Block
):
return
ast
.
Block
([
visit
(
e
)
for
e
in
object
.
args
])
else
:
assert
isinstance
(
eq
,
ast
.
Node
),
"Only equations and ast nodes are allowed in input"
typedEquations
.
append
(
eq
)
return
object
typedEquations
=
typedEquations
typedEquations
=
visit
(
eqs
)
return
fieldsRead
,
fieldsWritten
,
typedEquations
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment