Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Bindgen
pystencils
Commits
3de207f4
Commit
3de207f4
authored
Feb 23, 2017
by
Jan Hoenig
Browse files
Found some bugs. Dtypes should be now everywhere, i believe it was desympied now
parent
53b20223
Changes
4
Hide whitespace changes
Inline
Side-by-side
astnodes.py
View file @
3de207f4
...
...
@@ -342,7 +342,8 @@ class SympyAssignment(Node):
def
replace
(
self
,
child
,
replacement
):
if
child
==
self
.
lhs
:
self
.
lhs
=
child
replacement
.
parent
=
self
self
.
lhs
=
replacement
elif
child
==
self
.
rhs
:
replacement
.
parent
=
self
self
.
rhs
=
replacement
...
...
@@ -478,6 +479,10 @@ class Pow(Expr):
class
Indexed
(
Expr
):
def
__init__
(
self
,
args
,
base
,
parent
=
None
):
super
(
Indexed
,
self
).
__init__
(
args
,
parent
)
self
.
base
=
base
def
__repr__
(
self
):
return
'%s[%s]'
%
(
self
.
args
[
0
],
self
.
args
[
1
])
...
...
@@ -506,6 +511,6 @@ class Number(Node, sp.AtomicExpr):
raise
set
()
def
__repr__
(
self
):
return
repr
(
self
.
dtype
)
+
repr
(
self
.
value
)
return
repr
(
self
.
value
)
backends/__init__.py
View file @
3de207f4
from
.llvm
import
generateLLVM
from
.cbackend
import
generateC
,
generateCUDA
from
.dot
import
dotprint
backends/dot.py
View file @
3de207f4
...
...
@@ -6,9 +6,10 @@ class DotPrinter(Printer):
"""
A printer which converts ast to DOT (graph description language).
"""
def
__init__
(
self
,
nodeToStrFunction
,
**
kwargs
):
def
__init__
(
self
,
nodeToStrFunction
,
full
,
**
kwargs
):
super
(
DotPrinter
,
self
).
__init__
()
self
.
_nodeToStrFunction
=
nodeToStrFunction
self
.
full
=
full
self
.
dot
=
Digraph
(
**
kwargs
)
self
.
dot
.
quote_edge
=
lang
.
quote
...
...
@@ -30,6 +31,21 @@ class DotPrinter(Printer):
def
_print_SympyAssignment
(
self
,
assignment
):
self
.
dot
.
node
(
self
.
_nodeToStrFunction
(
assignment
))
if
self
.
full
:
for
node
in
assignment
.
args
:
self
.
_print
(
node
)
for
node
in
assignment
.
args
:
self
.
dot
.
edge
(
self
.
_nodeToStrFunction
(
assignment
),
self
.
_nodeToStrFunction
(
node
))
def
emptyPrinter
(
self
,
expr
):
if
self
.
full
:
self
.
dot
.
node
(
self
.
_nodeToStrFunction
(
expr
))
for
node
in
expr
.
args
:
self
.
_print
(
node
)
for
node
in
expr
.
args
:
self
.
dot
.
edge
(
self
.
_nodeToStrFunction
(
expr
),
self
.
_nodeToStrFunction
(
node
))
else
:
raise
NotImplemented
(
'Dotprinter cannot print'
,
expr
)
def
doprint
(
self
,
expr
):
self
.
_print
(
expr
)
...
...
@@ -48,17 +64,20 @@ def __shortened(node):
return
"Assignment: "
+
repr
(
node
.
lhs
)
def
dotprint
(
ast
,
view
=
False
,
short
=
False
,
**
kwargs
):
def
dotprint
(
node
,
view
=
False
,
short
=
False
,
full
=
False
,
**
kwargs
):
"""
Returns a string which can be used to generate a DOT-graph
:param
ast
: The ast which should be generated
:param
node
: The ast which should be generated
:param view: Boolen, if rendering of the image directly should occur.
:param short: Uses the __shortened output
:param full: Prints the whole tree with type information
:param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
:return: string in DOT format
"""
nodeToStrFunction
=
__shortened
if
short
else
repr
printer
=
DotPrinter
(
nodeToStrFunction
,
**
kwargs
)
dot
=
printer
.
doprint
(
ast
)
nodeToStrFunction
=
lambda
expr
:
repr
(
type
(
expr
))
+
repr
(
expr
)
if
full
else
nodeToStrFunction
printer
=
DotPrinter
(
nodeToStrFunction
,
full
,
**
kwargs
)
dot
=
printer
.
doprint
(
node
)
if
view
:
printer
.
dot
.
render
(
view
=
view
)
return
dot
...
...
@@ -80,4 +99,4 @@ if __name__ == "__main__":
from
pystencils.cpu
import
createKernel
ast
=
createKernel
([
updateRule
])
print
(
dotprint
(
ast
,
short
=
True
))
\ No newline at end of file
print
(
dotprint
(
ast
,
short
=
True
))
transformations.py
View file @
3de207f4
...
...
@@ -548,7 +548,7 @@ def get_type(node):
def
insert_casts
(
node
):
"""
Inserts casts where needed
Inserts casts
and dtype
where needed
:param node: ast which should be traversed
:return: node
"""
...
...
@@ -559,7 +559,7 @@ def insert_casts(node):
print
(
arg
)
insert_casts
(
arg
)
if
isinstance
(
node
,
ast
.
Indexed
):
pass
node
.
dtype
=
node
.
base
.
label
.
dtype
elif
isinstance
(
node
,
ast
.
Expr
):
print
(
node
)
print
([(
arg
,
type
(
arg
),
arg
.
dtype
,
type
(
arg
.
dtype
))
for
arg
in
node
.
args
])
...
...
@@ -594,9 +594,31 @@ def desympy_ast(node):
node
.
replace
(
arg
,
ast
.
Mul
(
arg
.
args
,
node
))
elif
isinstance
(
arg
,
sp
.
Pow
):
node
.
replace
(
arg
,
ast
.
Pow
(
arg
.
args
,
node
))
elif
isinstance
(
arg
,
sp
.
tensor
.
Indexed
):
node
.
replace
(
arg
,
ast
.
Indexed
(
arg
.
args
,
node
))
#elif isinstance(arg, )
elif
isinstance
(
arg
,
sp
.
tensor
.
Indexed
)
or
isinstance
(
arg
,
sp
.
tensor
.
indexed
.
Indexed
):
node
.
replace
(
arg
,
ast
.
Indexed
(
arg
.
args
,
arg
.
base
,
node
))
elif
isinstance
(
arg
,
sp
.
tensor
.
IndexedBase
):
node
.
replace
(
arg
,
arg
.
label
)
#elif isinstance(arg, sp.containers.Tuple):
#
else
:
print
(
'Not transforming:'
,
arg
,
type
(
arg
))
for
arg
in
node
.
args
:
desympy_ast
(
arg
)
return
node
def
check_dtype
(
node
):
if
isinstance
(
node
,
ast
.
KernelFunction
):
pass
elif
isinstance
(
node
,
ast
.
Block
):
pass
elif
isinstance
(
node
,
ast
.
LoopOverCoordinate
):
pass
elif
isinstance
(
node
,
ast
.
SympyAssignment
):
pass
else
:
print
(
node
)
print
(
node
.
dtype
)
for
arg
in
node
.
args
:
check_dtype
(
arg
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment