dot.py 3.42 KB
Newer Older
1
from sympy.printing.printer import Printer
2
from graphviz import Digraph, lang
3
import graphviz
4
5


Martin Bauer's avatar
Martin Bauer committed
6
# noinspection PyPep8Naming
7
8
9
10
class DotPrinter(Printer):
    """
    A printer which converts ast to DOT (graph description language).

    Each supported AST node (kernel function, loop, block, assignment, conditional) is emitted as a
    graph node whose DOT identifier is ``str(id(node))``. Edges connect a parent to its children.
    Every node label, blocks included, is produced by the ``node_to_str_function`` given to the
    constructor.
    """

    def __init__(self, node_to_str_function, **kwargs):
        """
        :param node_to_str_function: callable mapping an AST node to its label string
        :param kwargs: passed unchanged to ``graphviz.Digraph``
        """
        super(DotPrinter, self).__init__()
        self._node_to_str_function = node_to_str_function
        self.dot = Digraph(**kwargs)
        # Edge endpoints are quoted with plain ``lang.quote`` instead of Digraph's default edge quoting.
        # NOTE(review): presumably to bypass graphviz's ``node:port`` parsing of edge endpoints — confirm.
        self.dot.quote_edge = lang.quote

    @staticmethod
    def _node_id(node):
        """Return the DOT identifier used for ``node`` (its Python object id, as a string)."""
        return str(id(node))

    def _print_KernelFunction(self, func):
        self.dot.node(self._node_id(func), style='filled', fillcolor='#a056db',
                      label=self._node_to_str_function(func))
        self._print(func.body)
        self.dot.edge(self._node_id(func), self._node_id(func.body))

    def _print_LoopOverCoordinate(self, loop):
        self.dot.node(self._node_id(loop), style='filled', fillcolor='#3498db',
                      label=self._node_to_str_function(loop))
        self._print(loop.body)
        self.dot.edge(self._node_id(loop), self._node_id(loop.body))

    def _print_Block(self, block):
        # Children are printed first; the block node and its edges are emitted afterwards.
        for node in block.args:
            self._print(node)

        # Use the configured label function like every other node type, so that ``short`` output
        # labels blocks compactly instead of always dumping their full repr.
        self.dot.node(self._node_id(block), style='filled', fillcolor='#dbc256',
                      label=self._node_to_str_function(block))
        for node in block.args:
            self.dot.edge(self._node_id(block), self._node_id(node))

    def _print_SympyAssignment(self, assignment):
        # Assignments are leaves: a node without outgoing edges.
        self.dot.node(self._node_id(assignment), style='filled', fillcolor='#56db7f',
                      label=self._node_to_str_function(assignment))

    def _print_Conditional(self, expr):
        self.dot.node(self._node_id(expr), style='filled', fillcolor='#56bd7f',
                      label=self._node_to_str_function(expr))
        self._print(expr.true_block)
        self.dot.edge(self._node_id(expr), self._node_id(expr.true_block))
        # The false branch is only drawn when ``false_block`` is truthy (e.g. not None).
        if expr.false_block:
            self._print(expr.false_block)
            self.dot.edge(self._node_id(expr), self._node_id(expr.false_block))

    def doprint(self, expr):
        """
        Walk the AST rooted at ``expr`` and return the accumulated DOT source.

        :param expr: root AST node
        :return: DOT source string of the underlying ``Digraph``
        """
        self._print(expr)
        return self.dot.source


Martin Bauer's avatar
Martin Bauer committed
52
def __shortened(node):
    """
    Build a compact, human-readable label for an AST node (used by ``print_dot(..., short=True)``).

    :param node: a pystencils AST node
    :return: the label string
    :raises NotImplementedError: for node types that have no short representation
    """
    from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Block, Conditional

    if isinstance(node, LoopOverCoordinate):
        return "Loop over dim %d" % (node.coordinate_to_loop_over,)

    if isinstance(node, KernelFunction):
        parameters = node.get_parameters()
        # Field pointers are listed first (by field name), then all non-field parameters (by symbol name).
        field_labels = [param.field_name for param in parameters if param.is_field_pointer]
        symbol_labels = [param.symbol.name for param in parameters if not param.is_field_parameter]
        return "Func: %s (%s)" % (node.function_name, ",".join(field_labels + symbol_labels))

    if isinstance(node, SympyAssignment):
        return repr(node.lhs)

    if isinstance(node, Block):
        return "Block" + str(id(node))

    if isinstance(node, Conditional):
        return repr(node)

    raise NotImplementedError("Cannot handle node type %s" % (type(node),))
Martin Bauer's avatar
Martin Bauer committed
69
70


71
def print_dot(node, view=False, short=False, **kwargs):
    """
    Render an AST as a DOT graph description.

    :param node: root of the AST to visualize
    :param view: if True, return the DOT source wrapped in a ``graphviz.Source`` object
                 instead of the plain string
    :param short: label nodes with the compact ``__shortened`` representation instead of ``repr``
    :param kwargs: forwarded through ``DotPrinter`` to ``graphviz.Digraph``:
                   http://graphviz.readthedocs.io/en/latest/api.html#digraph
    :return: DOT source string, or a ``graphviz.Source`` wrapping it when ``view`` is True
    """
    label_for = __shortened if short else repr
    dot_source = DotPrinter(label_for, **kwargs).doprint(node)
    return graphviz.Source(dot_source) if view else dot_source