Commit ab3fd339 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: display

- colored dot output of AST
- various display helpers used in documentation
parent 5eefd760
......@@ -6,17 +6,18 @@ class DotPrinter(Printer):
A printer which converts ast to DOT (graph description language).
def __init__(self, **kwargs):
def __init__(self, nodeToStrFunction, **kwargs):
self._nodeToStrFunction = nodeToStrFunction = Digraph(**kwargs) = lang.quote
def _print_KernelFunction(self, function):, style='filled', fillcolor='#E69F00')
def _print_LoopOverCoordinate(self, loop):, style='filled', fillcolor='#56B4E9')
def _print_Block(self, block):
......@@ -24,18 +25,30 @@ class DotPrinter(Printer):
parent = block.parent
for node in block.children():, repr(node))
parent = node, self._nodeToStrFunction(node))
#parent = node
def _print_SympyAssignment(self, assignment):
def doprint(self, expr):
def dotprint(ast, view=False, **kwargs):
def __shortened(node):
from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment
if isinstance(node, LoopOverCoordinate):
return "Loop over dim %d" % (node.coordinateToLoopOver,)
elif isinstance(node, KernelFunction):
params = [ for f in node.fieldsAccessed]
params += [ for p in node.parameters if not p.isFieldArgument]
return "Func: %s (%s)" % (node.functionName, ",".join(params))
elif isinstance(node, SympyAssignment):
return "Assignment: " + repr(node.lhs)
def dotprint(ast, view=False, short=False, **kwargs):
Returns a string which can be used to generate a DOT-graph
:param ast: The ast which should be generated
......@@ -43,8 +56,28 @@ def dotprint(ast, view=False, **kwargs):
:param kwargs: is directly passed to the DotPrinter class:
:return: string in DOT format
printer = DotPrinter(**kwargs)
nodeToStrFunction = __shortened if short else repr
printer = DotPrinter(nodeToStrFunction, **kwargs)
dot = printer.doprint(ast)
if view:
return dot
if __name__ == "__main__":
from pystencils import Field
import sympy as sp
imgField = Field.createGeneric('I',
spatialDimensions=2, # 2D image
indexDimensions=1) # multiple values per pixel: e.g. RGB
w1, w2 = sp.symbols("w_1 w_2")
sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \
+ w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1)
dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0)
updateRule = sp.Eq(dstField[0, 0], sobelX)
from pystencils.cpu import createKernel
ast = createKernel([updateRule])
print(dotprint(ast, short=True))
\ No newline at end of file
def toDot(expr, graphStyle={}):
"""Show a sympy or pystencils AST as dot graph"""
from pystencils.ast import Node
import graphviz
if isinstance(expr, Node):
from import dotprint
return graphviz.Source(dotprint(expr, short=True, graph_attr=graphStyle))
from import dotprint
return graphviz.Source(dotprint(expr, graph_attr=graphStyle))
def highlightCpp(code):
"""Highlight the given C/C++ source code with Pygments"""
from IPython.display import HTML, display
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import CppLexer
return HTML(highlight(code, CppLexer(), HtmlFormatter()))
# ----------------- Embedding of animations as videos in IPython notebooks ---------------------------------------------
# ------- Version 1: Animation is embedded as an HTML5 Video tag ---------------------------------------
VIDEO_TAG = """<video controls width="100%">
<source src="data:video/x-m4v;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
def __anim_to_html(anim, fps):
from tempfile import NamedTemporaryFile
import base64
if not hasattr(anim, '_encoded_video'):
with NamedTemporaryFile(suffix='.mp4') as f:, fps=fps, extra_args=['-vcodec', 'libx264', '-pix_fmt',
'yuv420p', '-profile:v', 'baseline', '-level', '3.0'])
video = open(, "rb").read()
anim._encoded_video = base64.b64encode(video).decode('ascii')
return VIDEO_TAG.format(anim._encoded_video)
def disp_as_video(anim, fps=30, show=True, **kwargs):
import matplotlib.pyplot as plt
from IPython.display import HTML
res = __anim_to_html(anim, fps)
if show:
return HTML(res)
return HTML("")
except KeyboardInterrupt:
# ------- Version 2: Animation is shown in extra matplotlib window ----------------------------------
def disp_extra_window(animation, *args,**kwargs):
import matplotlib.pyplot as plt
fig = plt.gcf()
except Exception:
# ------- Version 3: Animation is shown in images that are updated directly in website --------------
def disp_image_update(animation, iterations=10000, *args, **kwargs):
from IPython import display
import matplotlib.pyplot as plt
fig = plt.gcf()
for i in range(iterations):
except KeyboardInterrupt:
# Dispatcher
animation_display_mode = 'imageupdate'
display_animation_func = None
def disp(*args, **kwargs):
if not display_animation_func:
raise Exception("Call set_display_mode first")
return display_animation_func(*args, **kwargs)
def set_display_mode(mode):
from IPython import get_ipython
ipython = get_ipython()
global animation_display_mode
global display_animation_func
animation_display_mode = mode
if animation_display_mode == 'video':
ipython.magic("matplotlib inline")
display_animation_func = disp_as_video
elif animation_display_mode == 'window':
ipython.magic("matplotlib qt")
display_animation_func = disp_extra_window
elif animation_display_mode == 'imageupdate':
ipython.magic("matplotlib inline")
display_animation_func = disp_image_update
raise Exception("Unknown mode. Available modes 'imageupdate', 'video' and 'window' ")
# --------------------- Convenience functions --------------------------------------------------------------------------
def makeSurfacePlotAnimation(runFunction, frames=90, interval=30):
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib import cm
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y, data = runFunction(1)
ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
ax.set_zlim(-1.0, 1.0)
def updatefig(*args):
X, Y, data = runFunction(1)
plot = ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
ax.set_zlim(-1.0, 1.0)
return plot,
return animation.FuncAnimation(fig, updatefig, interval=interval, frames=frames, blit=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment