From ab3fd339089e23695af09f777f51a399fcb77506 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 5 Dec 2016 16:22:17 +0100
Subject: [PATCH] pystencils: display

- colored dot output of AST
- various display helpers used in documentation
---
 backends/dot.py  |  49 ++++++++++++---
 display_utils.py | 155 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 196 insertions(+), 8 deletions(-)
 create mode 100644 display_utils.py

diff --git a/backends/dot.py b/backends/dot.py
index fa0f16ec8..3578e525f 100644
--- a/backends/dot.py
+++ b/backends/dot.py
@@ -6,17 +6,18 @@ class DotPrinter(Printer):
     """
     A printer which converts ast to DOT (graph description language).
     """
-    def __init__(self, **kwargs):
+    def __init__(self, nodeToStrFunction, **kwargs):
         super().__init__()
+        self._nodeToStrFunction = nodeToStrFunction
         self.dot = Digraph(**kwargs)
         self.dot.quote_edge = lang.quote
 
     def _print_KernelFunction(self, function):
-        self.dot.node(repr(function))
+        self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00')
         self._print(function.body)
 
     def _print_LoopOverCoordinate(self, loop):
-        self.dot.node(repr(loop))
+        self.dot.node(self._nodeToStrFunction(loop), style='filled', fillcolor='#56B4E9')
         self._print(loop.body)
 
     def _print_Block(self, block):
@@ -24,18 +25,30 @@ class DotPrinter(Printer):
             self._print(node)
         parent = block.parent
         for node in block.children():
-            self.dot.edge(repr(parent), repr(node))
-            parent = node
+            self.dot.edge(self._nodeToStrFunction(parent), self._nodeToStrFunction(node))
+            #parent = node
 
     def _print_SympyAssignment(self, assignment):
-        self.dot.node(repr(assignment))
+        self.dot.node(self._nodeToStrFunction(assignment))
 
     def doprint(self, expr):
         self._print(expr)
         return self.dot.source
 
 
-def dotprint(ast, view=False, **kwargs):
+def __shortened(node):
+    from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment
+    if isinstance(node, LoopOverCoordinate):
+        return "Loop over dim %d" % (node.coordinateToLoopOver,)
+    elif isinstance(node, KernelFunction):
+        params = [f.name for f in node.fieldsAccessed]
+        params += [p.name for p in node.parameters if not p.isFieldArgument]
+        return "Func: %s (%s)" % (node.functionName, ",".join(params))
+    elif isinstance(node, SympyAssignment):
+        return "Assignment: " + repr(node.lhs)
+
+
+def dotprint(ast, view=False, short=False, **kwargs):
     """
     Returns a string which can be used to generate a DOT-graph
     :param ast: The ast which should be generated
@@ -43,8 +56,28 @@ def dotprint(ast, view=False, **kwargs):
     :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
     :return: string in DOT format
     """
-    printer = DotPrinter(**kwargs)
+    nodeToStrFunction = __shortened if short else repr
+    printer = DotPrinter(nodeToStrFunction, **kwargs)
     dot = printer.doprint(ast)
     if view:
         printer.dot.render(view=view)
     return dot
+
+if __name__ == "__main__":
+    from pystencils import Field
+    import sympy as sp
+    imgField = Field.createGeneric('I',
+                                   spatialDimensions=2, # 2D image
+                                   indexDimensions=1)   # multiple values per pixel: e.g. RGB
+    w1, w2 = sp.symbols("w_1 w_2")
+    sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \
+             + w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) - w1 * imgField[+1, +1](1)
+    sobelX
+
+    dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0)
+    updateRule = sp.Eq(dstField[0, 0], sobelX)
+    updateRule
+
+    from pystencils.cpu import createKernel
+    ast = createKernel([updateRule])
+    print(dotprint(ast, short=True))
\ No newline at end of file
diff --git a/display_utils.py b/display_utils.py
new file mode 100644
index 000000000..9fff9df1d
--- /dev/null
+++ b/display_utils.py
@@ -0,0 +1,155 @@
+
+
+def toDot(expr, graphStyle={}):
+    """Show a sympy or pystencils AST as dot graph"""
+    from pystencils.ast import Node
+    import graphviz
+    if isinstance(expr, Node):
+        from pystencils.backends.dot import dotprint
+        return graphviz.Source(dotprint(expr, short=True, graph_attr=graphStyle))
+    else:
+        from sympy.printing.dot import dotprint
+        return graphviz.Source(dotprint(expr, graph_attr=graphStyle))
+
+
+def highlightCpp(code):
+    """Highlight the given C/C++ source code with Pygments"""
+    from IPython.display import HTML, display
+    from pygments import highlight
+    from pygments.formatters import HtmlFormatter
+    from pygments.lexers import CppLexer
+
+    display(HTML("""
+            <style>
+            {pygments_css}
+            </style>
+            """.format(pygments_css=HtmlFormatter().get_style_defs('.highlight'))))
+    return HTML(highlight(code, CppLexer(), HtmlFormatter()))
+
+
+# ----------------- Embedding of animations as videos in IPython notebooks ---------------------------------------------
+
+
+# -------   Version 1: Animation is embedded as an HTML5 Video tag  ---------------------------------------
+
+VIDEO_TAG = """<video controls width="100%">
+ <source src="data:video/x-m4v;base64,{0}" type="video/mp4">
+ Your browser does not support the video tag.
+</video>"""
+
+
+def __anim_to_html(anim, fps):
+    from tempfile import NamedTemporaryFile
+    import base64
+
+    if not hasattr(anim, '_encoded_video'):
+        with NamedTemporaryFile(suffix='.mp4') as f:
+            anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264', '-pix_fmt',
+                                                   'yuv420p', '-profile:v', 'baseline', '-level', '3.0'])
+            video = open(f.name, "rb").read()
+        anim._encoded_video = base64.b64encode(video).decode('ascii')
+
+    return VIDEO_TAG.format(anim._encoded_video)
+
+
+def disp_as_video(anim, fps=30, show=True, **kwargs):
+    import matplotlib.pyplot as plt
+    from IPython.display import HTML
+
+    try:
+        plt.close(anim._fig)
+        res = __anim_to_html(anim, fps)
+        if show:
+            return HTML(res)
+        else:
+            return HTML("")
+    except KeyboardInterrupt:
+        pass
+# -------   Version 2: Animation is shown in extra matplotlib window ----------------------------------
+
+
+def disp_extra_window(animation, *args,**kwargs):
+    import matplotlib.pyplot as plt
+
+    fig = plt.gcf()
+    try:
+      fig.canvas.manager.window.raise_()
+    except Exception:
+      pass
+    plt.show()
+
+
+# -------   Version 3: Animation is shown in images that are updated directly in website --------------
+
+def disp_image_update(animation, iterations=10000, *args, **kwargs):
+    from IPython import display
+    import matplotlib.pyplot as plt
+
+    try:
+        fig = plt.gcf()
+        animation._init_draw()
+        for i in range(iterations):
+            display.display(fig)
+            animation._step()
+            display.clear_output(wait=True)
+    except KeyboardInterrupt:
+        pass
+
+
+# Dispatcher
+
+animation_display_mode = 'imageupdate'
+display_animation_func = None
+
+
+def disp(*args, **kwargs):
+    if not display_animation_func:
+        raise Exception("Call set_display_mode first")
+    return display_animation_func(*args, **kwargs)
+
+
+def set_display_mode(mode):
+    from IPython import get_ipython
+    ipython = get_ipython()
+    global animation_display_mode
+    global display_animation_func
+    animation_display_mode = mode
+    if animation_display_mode == 'video':
+        ipython.magic("matplotlib inline")
+        display_animation_func = disp_as_video
+    elif animation_display_mode == 'window':
+        ipython.magic("matplotlib qt")
+        display_animation_func = disp_extra_window
+    elif animation_display_mode == 'imageupdate':
+        ipython.magic("matplotlib inline")
+        display_animation_func = disp_image_update
+    else:
+        raise Exception("Unknown mode. Available modes 'imageupdate', 'video' and 'window' ")
+
+
+set_display_mode('video')
+
+
+# --------------------- Convenience functions --------------------------------------------------------------------------
+
+
+def makeSurfacePlotAnimation(runFunction, frames=90, interval=30):
+    from mpl_toolkits.mplot3d import Axes3D
+    import matplotlib.animation as animation
+    import matplotlib.pyplot as plt
+    from matplotlib import cm
+
+    fig = plt.figure()
+    ax = fig.add_subplot(111, projection='3d')
+    X, Y, data = runFunction(1)
+    ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
+    ax.set_zlim(-1.0, 1.0)
+
+    def updatefig(*args):
+        X, Y, data = runFunction(1)
+        ax.clear()
+        plot = ax.plot_surface(X, Y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
+        ax.set_zlim(-1.0, 1.0)
+        return plot,
+
+    return animation.FuncAnimation(fig, updatefig, interval=interval, frames=frames, blit=False)
-- 
GitLab