Skip to content
Snippets Groups Projects
Commit ab3fd339 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: display

- colored dot output of AST
- various display helpers used in documentation
parent 5eefd760
Branches
Tags
No related merge requests found
...@@ -6,17 +6,18 @@ class DotPrinter(Printer): ...@@ -6,17 +6,18 @@ class DotPrinter(Printer):
""" """
A printer which converts ast to DOT (graph description language). A printer which converts ast to DOT (graph description language).
""" """
def __init__(self, **kwargs): def __init__(self, nodeToStrFunction, **kwargs):
super().__init__() super().__init__()
self._nodeToStrFunction = nodeToStrFunction
self.dot = Digraph(**kwargs) self.dot = Digraph(**kwargs)
self.dot.quote_edge = lang.quote self.dot.quote_edge = lang.quote
def _print_KernelFunction(self, function): def _print_KernelFunction(self, function):
self.dot.node(repr(function)) self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00')
self._print(function.body) self._print(function.body)
def _print_LoopOverCoordinate(self, loop): def _print_LoopOverCoordinate(self, loop):
self.dot.node(repr(loop)) self.dot.node(self._nodeToStrFunction(loop), style='filled', fillcolor='#56B4E9')
self._print(loop.body) self._print(loop.body)
def _print_Block(self, block): def _print_Block(self, block):
...@@ -24,18 +25,30 @@ class DotPrinter(Printer): ...@@ -24,18 +25,30 @@ class DotPrinter(Printer):
self._print(node) self._print(node)
parent = block.parent parent = block.parent
for node in block.children(): for node in block.children():
self.dot.edge(repr(parent), repr(node)) self.dot.edge(self._nodeToStrFunction(parent), self._nodeToStrFunction(node))
parent = node #parent = node
def _print_SympyAssignment(self, assignment): def _print_SympyAssignment(self, assignment):
self.dot.node(repr(assignment)) self.dot.node(self._nodeToStrFunction(assignment))
def doprint(self, expr): def doprint(self, expr):
self._print(expr) self._print(expr)
return self.dot.source return self.dot.source
def dotprint(ast, view=False, **kwargs): def __shortened(node):
from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment
if isinstance(node, LoopOverCoordinate):
return "Loop over dim %d" % (node.coordinateToLoopOver,)
elif isinstance(node, KernelFunction):
params = [f.name for f in node.fieldsAccessed]
params += [p.name for p in node.parameters if not p.isFieldArgument]
return "Func: %s (%s)" % (node.functionName, ",".join(params))
elif isinstance(node, SympyAssignment):
return "Assignment: " + repr(node.lhs)
def dotprint(ast, view=False, short=False, **kwargs):
""" """
Returns a string which can be used to generate a DOT-graph Returns a string which can be used to generate a DOT-graph
:param ast: The ast which should be generated :param ast: The ast which should be generated
...@@ -43,8 +56,28 @@ def dotprint(ast, view=False, **kwargs): ...@@ -43,8 +56,28 @@ def dotprint(ast, view=False, **kwargs):
:param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
:return: string in DOT format :return: string in DOT format
""" """
printer = DotPrinter(**kwargs) nodeToStrFunction = __shortened if short else repr
printer = DotPrinter(nodeToStrFunction, **kwargs)
dot = printer.doprint(ast) dot = printer.doprint(ast)
if view: if view:
printer.dot.render(view=view) printer.dot.render(view=view)
return dot return dot
if __name__ == "__main__":
from pystencils import Field
import sympy as sp
imgField = Field.createGeneric('I',
spatialDimensions=2, # 2D image
indexDimensions=1) # multiple values per pixel: e.g. RGB
w1, w2 = sp.symbols("w_1 w_2")
sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \
+ w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1)
sobelX
dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0)
updateRule = sp.Eq(dstField[0, 0], sobelX)
updateRule
from pystencils.cpu import createKernel
ast = createKernel([updateRule])
print(dotprint(ast, short=True))
\ No newline at end of file
def toDot(expr, graphStyle={}):
"""Show a sympy or pystencils AST as dot graph"""
from pystencils.ast import Node
import graphviz
if isinstance(expr, Node):
from pystencils.backends.dot import dotprint
return graphviz.Source(dotprint(expr, short=True, graph_attr=graphStyle))
else:
from sympy.printing.dot import dotprint
return graphviz.Source(dotprint(expr, graph_attr=graphStyle))
def highlightCpp(code):
"""Highlight the given C/C++ source code with Pygments"""
from IPython.display import HTML, display
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import CppLexer
display(HTML("""
<style>
{pygments_css}
</style>
""".format(pygments_css=HtmlFormatter().get_style_defs('.highlight'))))
return HTML(highlight(code, CppLexer(), HtmlFormatter()))
# ----------------- Embedding of animations as videos in IPython notebooks ---------------------------------------------
# ------- Version 1: Animation is embedded as an HTML5 Video tag ---------------------------------------
VIDEO_TAG = """<video controls width="100%">
<source src="data:video/x-m4v;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
</video>"""
def __anim_to_html(anim, fps):
from tempfile import NamedTemporaryFile
import base64
if not hasattr(anim, '_encoded_video'):
with NamedTemporaryFile(suffix='.mp4') as f:
anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264', '-pix_fmt',
'yuv420p', '-profile:v', 'baseline', '-level', '3.0'])
video = open(f.name, "rb").read()
anim._encoded_video = base64.b64encode(video).decode('ascii')
return VIDEO_TAG.format(anim._encoded_video)
def disp_as_video(anim, fps=30, show=True, **kwargs):
import matplotlib.pyplot as plt
from IPython.display import HTML
try:
plt.close(anim._fig)
res = __anim_to_html(anim, fps)
if show:
return HTML(res)
else:
return HTML("")
except KeyboardInterrupt:
pass
# ------- Version 2: Animation is shown in extra matplotlib window ----------------------------------
def disp_extra_window(animation, *args,**kwargs):
import matplotlib.pyplot as plt
fig = plt.gcf()
try:
fig.canvas.manager.window.raise_()
except Exception:
pass
plt.show()
# ------- Version 3: Animation is shown in images that are updated directly in website --------------
def disp_image_update(animation, iterations=10000, *args, **kwargs):
from IPython import display
import matplotlib.pyplot as plt
try:
fig = plt.gcf()
animation._init_draw()
for i in range(iterations):
display.display(fig)
animation._step()
display.clear_output(wait=True)
except KeyboardInterrupt:
pass
# Dispatcher
animation_display_mode = 'imageupdate'
display_animation_func = None
def disp(*args, **kwargs):
if not display_animation_func:
raise Exception("Call set_display_mode first")
return display_animation_func(*args, **kwargs)
def set_display_mode(mode):
from IPython import get_ipython
ipython = get_ipython()
global animation_display_mode
global display_animation_func
animation_display_mode = mode
if animation_display_mode == 'video':
ipython.magic("matplotlib inline")
display_animation_func = disp_as_video
elif animation_display_mode == 'window':
ipython.magic("matplotlib qt")
display_animation_func = disp_extra_window
elif animation_display_mode == 'imageupdate':
ipython.magic("matplotlib inline")
display_animation_func = disp_image_update
else:
raise Exception("Unknown mode. Available modes 'imageupdate', 'video' and 'window' ")
set_display_mode('video')
# --------------------- Convenience functions --------------------------------------------------------------------------
def makeSurfacePlotAnimation(runFunction, frames=90, interval=30):
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib import cm
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y, data = runFunction(1)
ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
ax.set_zlim(-1.0, 1.0)
def updatefig(*args):
X, Y, data = runFunction(1)
ax.clear()
plot = ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
ax.set_zlim(-1.0, 1.0)
return plot,
return animation.FuncAnimation(fig, updatefig, interval=interval, frames=frames, blit=False)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment