From ab3fd339089e23695af09f777f51a399fcb77506 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Mon, 5 Dec 2016 16:22:17 +0100 Subject: [PATCH] pystencils: display - colored dot output of AST - various display helpers used in documentation --- backends/dot.py | 49 ++++++++++++--- display_utils.py | 155 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 8 deletions(-) create mode 100644 display_utils.py diff --git a/backends/dot.py b/backends/dot.py index fa0f16ec8..3578e525f 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -6,17 +6,18 @@ class DotPrinter(Printer): """ A printer which converts ast to DOT (graph description language). """ - def __init__(self, **kwargs): + def __init__(self, nodeToStrFunction, **kwargs): super().__init__() + self._nodeToStrFunction = nodeToStrFunction self.dot = Digraph(**kwargs) self.dot.quote_edge = lang.quote def _print_KernelFunction(self, function): - self.dot.node(repr(function)) + self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00') self._print(function.body) def _print_LoopOverCoordinate(self, loop): - self.dot.node(repr(loop)) + self.dot.node(self._nodeToStrFunction(loop), style='filled', fillcolor='#56B4E9') self._print(loop.body) def _print_Block(self, block): @@ -24,18 +25,30 @@ class DotPrinter(Printer): self._print(node) parent = block.parent for node in block.children(): - self.dot.edge(repr(parent), repr(node)) - parent = node + self.dot.edge(self._nodeToStrFunction(parent), self._nodeToStrFunction(node)) + #parent = node def _print_SympyAssignment(self, assignment): - self.dot.node(repr(assignment)) + self.dot.node(self._nodeToStrFunction(assignment)) def doprint(self, expr): self._print(expr) return self.dot.source -def dotprint(ast, view=False, **kwargs): +def __shortened(node): + from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment + if isinstance(node, LoopOverCoordinate): + return "Loop over dim %d" % (node.coordinateToLoopOver,) + elif isinstance(node, KernelFunction): + params = [f.name for f in node.fieldsAccessed] + params += [p.name for p in node.parameters if not p.isFieldArgument] + return "Func: %s (%s)" % (node.functionName, ",".join(params)) + elif isinstance(node, SympyAssignment): + return "Assignment: " + repr(node.lhs) + + +def dotprint(ast, view=False, short=False, **kwargs): """ Returns a string which can be used to generate a DOT-graph :param ast: The ast which should be generated @@ -43,8 +56,28 @@ def dotprint(ast, view=False, **kwargs): :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :return: string in DOT format """ - printer = DotPrinter(**kwargs) + nodeToStrFunction = __shortened if short else repr + printer = DotPrinter(nodeToStrFunction, **kwargs) dot = printer.doprint(ast) if view: printer.dot.render(view=view) return dot + +if __name__ == "__main__": + from pystencils import Field + import sympy as sp + imgField = Field.createGeneric('I', + spatialDimensions=2, # 2D image + indexDimensions=1) # multiple values per pixel: e.g. RGB + w1, w2 = sp.symbols("w_1 w_2") + sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \ + + w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1) + sobelX + + dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0) + updateRule = sp.Eq(dstField[0, 0], sobelX) + updateRule + + from pystencils.cpu import createKernel + ast = createKernel([updateRule]) + print(dotprint(ast, short=True)) \ No newline at end of file diff --git a/display_utils.py b/display_utils.py new file mode 100644 index 000000000..9fff9df1d --- /dev/null +++ b/display_utils.py @@ -0,0 +1,155 @@ + + +def toDot(expr, graphStyle={}): + """Show a sympy or pystencils AST as dot graph""" + from pystencils.ast import Node + import graphviz + if isinstance(expr, Node): + from pystencils.backends.dot import dotprint + return graphviz.Source(dotprint(expr, short=True, graph_attr=graphStyle)) + else: + from sympy.printing.dot import dotprint + return graphviz.Source(dotprint(expr, graph_attr=graphStyle)) + + +def highlightCpp(code): + """Highlight the given C/C++ source code with Pygments""" + from IPython.display import HTML, display + from pygments import highlight + from pygments.formatters import HtmlFormatter + from pygments.lexers import CppLexer + + display(HTML(""" + <style> + {pygments_css} + </style> + """.format(pygments_css=HtmlFormatter().get_style_defs('.highlight')))) + return HTML(highlight(code, CppLexer(), HtmlFormatter())) + + +# ----------------- Embedding of animations as videos in IPython notebooks --------------------------------------------- + + +# ------- Version 1: Animation is embedded as an HTML5 Video tag --------------------------------------- + +VIDEO_TAG = """<video controls width="100%"> + <source src="data:video/x-m4v;base64,{0}" type="video/mp4"> + Your browser does not support the video tag. +</video>""" + + +def __anim_to_html(anim, fps): + from tempfile import NamedTemporaryFile + import base64 + + if not hasattr(anim, '_encoded_video'): + with NamedTemporaryFile(suffix='.mp4') as f: + anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264', '-pix_fmt', + 'yuv420p', '-profile:v', 'baseline', '-level', '3.0']) + video = open(f.name, "rb").read() + anim._encoded_video = base64.b64encode(video).decode('ascii') + + return VIDEO_TAG.format(anim._encoded_video) + + +def disp_as_video(anim, fps=30, show=True, **kwargs): + import matplotlib.pyplot as plt + from IPython.display import HTML + + try: + plt.close(anim._fig) + res = __anim_to_html(anim, fps) + if show: + return HTML(res) + else: + return HTML("") + except KeyboardInterrupt: + pass +# ------- Version 2: Animation is shown in extra matplotlib window ---------------------------------- + + +def disp_extra_window(animation, *args,**kwargs): + import matplotlib.pyplot as plt + + fig = plt.gcf() + try: + fig.canvas.manager.window.raise_() + except Exception: + pass + plt.show() + + +# ------- Version 3: Animation is shown in images that are updated directly in website -------------- + +def disp_image_update(animation, iterations=10000, *args, **kwargs): + from IPython import display + import matplotlib.pyplot as plt + + try: + fig = plt.gcf() + animation._init_draw() + for i in range(iterations): + display.display(fig) + animation._step() + display.clear_output(wait=True) + except KeyboardInterrupt: + pass + + +# Dispatcher + +animation_display_mode = 'imageupdate' +display_animation_func = None + + +def disp(*args, **kwargs): + if not display_animation_func: + raise Exception("Call set_display_mode first") + return display_animation_func(*args, **kwargs) + + +def set_display_mode(mode): + from IPython import get_ipython + ipython = get_ipython() + global animation_display_mode + global display_animation_func + animation_display_mode = mode + if animation_display_mode == 'video': + ipython.magic("matplotlib inline") + display_animation_func = disp_as_video + elif animation_display_mode == 'window': + ipython.magic("matplotlib qt") + display_animation_func = disp_extra_window + elif animation_display_mode == 'imageupdate': + ipython.magic("matplotlib inline") + display_animation_func = disp_image_update + else: + raise Exception("Unknown mode. Available modes 'imageupdate', 'video' and 'window' ") + + +set_display_mode('video') + + +# --------------------- Convenience functions -------------------------------------------------------------------------- + + +def makeSurfacePlotAnimation(runFunction, frames=90, interval=30): + from mpl_toolkits.mplot3d import Axes3D + import matplotlib.animation as animation + import matplotlib.pyplot as plt + from matplotlib import cm + + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + X, Y, data = runFunction(1) + ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) + ax.set_zlim(-1.0, 1.0) + + def updatefig(*args): + X, Y, data = runFunction(1) + ax.clear() + plot = ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) + ax.set_zlim(-1.0, 1.0) + return plot, + + return animation.FuncAnimation(fig, updatefig, interval=interval, frames=frames, blit=False) -- GitLab