dot.py 4.05 KB
Newer Older
1
from sympy.printing.printer import Printer
2
from graphviz import Digraph, lang
3
import graphviz
4
5
6
7
8
9


class DotPrinter(Printer):
    """
    A printer which converts ast to DOT (graph description language).

    Every visited node becomes a DOT node whose identifier is ``str(id(node))``;
    parent/child relations become directed edges from parent to child. The graph
    is accumulated in ``self.dot`` (a ``graphviz.Digraph``) and its DOT source is
    returned by :meth:`doprint`.
    """

    def __init__(self, nodeToStrFunction, full, **kwargs):
        """
        :param nodeToStrFunction: callable mapping a node to the label string shown in the graph
        :param full: if True, also descend into the arguments of assignments and of any node
                     without a dedicated ``_print_*`` method; if False, assignments are drawn
                     as leaves and such other nodes raise NotImplementedError
        :param kwargs: forwarded unchanged to ``graphviz.Digraph``
        """
        super(DotPrinter, self).__init__()
        self._nodeToStrFunction = nodeToStrFunction
        self.full = full
        self.dot = Digraph(**kwargs)
        # Quote edge endpoints with the same plain identifier quoting used for nodes.
        # NOTE(review): presumably this keeps graphviz from interpreting ':' in an
        # endpoint as node:port syntax - confirm against the installed graphviz version.
        self.dot.quote_edge = lang.quote

    @staticmethod
    def _node_id(node):
        """Return the DOT identifier used for *node*: its Python object id as a string."""
        return str(id(node))

    def _print_KernelFunction(self, func):
        """Draw the kernel function node, visit its body and link function -> body."""
        self.dot.node(self._node_id(func), style='filled', fillcolor='#a056db', label=self._nodeToStrFunction(func))
        self._print(func.body)
        self.dot.edge(self._node_id(func), self._node_id(func.body))

    def _print_LoopOverCoordinate(self, loop):
        """Draw the loop node, visit its body and link loop -> body."""
        self.dot.node(self._node_id(loop), style='filled', fillcolor='#3498db', label=self._nodeToStrFunction(loop))
        self._print(loop.body)
        self.dot.edge(self._node_id(loop), self._node_id(loop.body))

    def _print_Block(self, block):
        """Visit all children first, then draw the block node and link it to each child.

        The block label is always ``repr(block)``, independent of ``nodeToStrFunction``.
        """
        for node in block.args:
            self._print(node)

        self.dot.node(self._node_id(block), style='filled', fillcolor='#dbc256', label=repr(block))
        for node in block.args:
            self.dot.edge(self._node_id(block), self._node_id(node))

    def _print_SympyAssignment(self, assignment):
        """Draw the assignment node; in ``full`` mode also visit and link its arguments."""
        self.dot.node(self._node_id(assignment), style='filled', fillcolor='#56db7f',
                      label=self._nodeToStrFunction(assignment))
        if self.full:
            # All children are emitted before any edge, matching the order used elsewhere.
            for node in assignment.args:
                self._print(node)
            for node in assignment.args:
                self.dot.edge(self._node_id(assignment), self._node_id(node))

    def _print_Conditional(self, expr):
        """Draw the conditional node and link it to its true block and, if present, its false block."""
        self.dot.node(self._node_id(expr), style='filled', fillcolor='#56bd7f', label=self._nodeToStrFunction(expr))
        self._print(expr.trueBlock)
        self.dot.edge(self._node_id(expr), self._node_id(expr.trueBlock))
        if expr.falseBlock:
            self._print(expr.falseBlock)
            self.dot.edge(self._node_id(expr), self._node_id(expr.falseBlock))

    def emptyPrinter(self, expr):
        """Fallback for node types without a dedicated ``_print_*`` method.

        In ``full`` mode the node is drawn with its label, every entry of
        ``expr.args`` is visited, and the node is linked to each of them.

        :raises NotImplementedError: if ``full`` is False
        """
        if not self.full:
            # A single formatted message; passing several positional args would
            # make the exception text render as a tuple.
            raise NotImplementedError('Dotprinter cannot print %s: %r (use full=True to print arbitrary nodes)'
                                      % (type(expr), expr))

        self.dot.node(self._node_id(expr), label=self._nodeToStrFunction(expr))
        for node in expr.args:
            self._print(node)
        for node in expr.args:
            self.dot.edge(self._node_id(expr), self._node_id(node))

    def doprint(self, expr):
        """Visit *expr* and return the DOT source of the accumulated graph.

        Nodes are added to ``self.dot``, which is never reset, so calling this
        repeatedly on one printer accumulates all visited trees into one graph.
        """
        self._print(expr)
        return self.dot.source


Martin Bauer's avatar
Martin Bauer committed
66
def __shortened(node):
    """Return a compact label for a pystencils AST node.

    Handles loops, kernel functions, assignments, blocks and conditionals;
    any other node type is rejected.

    :raises NotImplementedError: for node types not listed above
    """
    # Imported lazily, at call time rather than module import time.
    from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Block, Conditional

    if isinstance(node, LoopOverCoordinate):
        return "Loop over dim %d" % (node.coordinateToLoopOver,)

    if isinstance(node, KernelFunction):
        # Accessed fields first, then the non-field (scalar) parameters.
        field_names = [field.name for field in node.fieldsAccessed]
        scalar_names = [param.name for param in node.parameters if not param.isFieldArgument]
        return "Func: %s (%s)" % (node.functionName, ",".join(field_names + scalar_names))

    if isinstance(node, SympyAssignment):
        return repr(node.lhs)

    if isinstance(node, Block):
        return "Block" + str(id(node))

    if isinstance(node, Conditional):
        return repr(node)

    raise NotImplementedError("Cannot handle node type %s" % (type(node),))
Martin Bauer's avatar
Martin Bauer committed
82
83


84
def dotprint(node, view=False, short=False, full=False, **kwargs):
    """
    Returns a string which can be used to generate a DOT-graph

    :param node: The ast which should be generated
    :param view: Boolean; if True, return a ``graphviz.Source`` wrapping the DOT code
                 (which graphviz can render, e.g. inline in a Jupyter notebook) instead of the plain string
    :param short: Uses the __shortened output (compact labels for AST nodes). Takes precedence
                  over ``full`` for labelling.
    :param full: Prints the whole tree with type information, including the arguments of assignments
                 and of nodes without a dedicated printer method
    :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
    :return: string in DOT format, or a ``graphviz.Source`` of that string if ``view`` is True

    NOTE: combining ``short=True`` with ``full=True`` makes the printer visit non-AST arguments,
    which the shortened labeller rejects with NotImplementedError.
    """
    if short:
        node_to_str_function = __shortened
    elif full:
        def node_to_str_function(expr):
            # Prefix every label with the node's type for the full view.
            return repr(type(expr)) + repr(expr)
    else:
        node_to_str_function = repr

    printer = DotPrinter(node_to_str_function, full, **kwargs)
    dot = printer.doprint(node)
    if view:
        return graphviz.Source(dot)
    return dot
Martin Bauer's avatar
Martin Bauer committed
104