from sympy.printing.printer import Printer
from graphviz import Digraph, lang


class DotPrinter(Printer):
    """
    A printer which converts ast to DOT (graph description language).
    """

    def __init__(self, nodeToStrFunction, **kwargs):
        """
        :param nodeToStrFunction: callable mapping an AST node to the string used as its DOT node name
        :param kwargs: passed unchanged to ``graphviz.Digraph``
        """
        super().__init__()
        self._nodeToStrFunction = nodeToStrFunction
        self.dot = Digraph(**kwargs)
        # Short node names contain ':' (e.g. "Func: ...", "Assignment: ..."), so edge
        # endpoints are quoted with plain ``lang.quote``.
        # NOTE(review): presumably this stops graphviz from reading "a:b" as node:port;
        # confirm the attribute name matches the installed graphviz version.
        self.dot.quote_edge = lang.quote

    def _print_KernelFunction(self, function):
        self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00')
        self._print(function.body)

    def _print_LoopOverCoordinate(self, loop):
        self.dot.node(self._nodeToStrFunction(loop), style='filled', fillcolor='#56B4E9')
        self._print(loop.body)

    def _print_Block(self, block):
        """Print every child of the block, then add an edge from ``block.parent`` to each child.

        The block itself gets no DOT node; its children hang directly off the node
        that owns the block.
        """
        # Materialise once so both passes see the same children, even if
        # children() returns a one-shot iterator.
        children = list(block.children())
        for node in children:
            self._print(node)
        parentName = self._nodeToStrFunction(block.parent)
        for node in children:
            self.dot.edge(parentName, self._nodeToStrFunction(node))

    def _print_SympyAssignment(self, assignment):
        self.dot.node(self._nodeToStrFunction(assignment))

    def doprint(self, expr):
        """Walk ``expr`` and return the DOT source accumulated in ``self.dot``."""
        self._print(expr)
        return self.dot.source


def __shortened(node):
    """Return a short, human-readable label for an AST node.

    Nodes without a dedicated label fall back to ``repr(node)``, as in non-short
    mode, so the DOT graph never receives ``None`` as a node name.
    """
    from pystencils.ast import LoopOverCoordinate, KernelFunction, SympyAssignment
    if isinstance(node, LoopOverCoordinate):
        return "Loop over dim %d" % (node.coordinateToLoopOver,)
    elif isinstance(node, KernelFunction):
        params = [f.name for f in node.fieldsAccessed]
        params += [p.name for p in node.parameters if not p.isFieldArgument]
        return "Func: %s (%s)" % (node.functionName, ",".join(params))
    elif isinstance(node, SympyAssignment):
        return "Assignment: " + repr(node.lhs)
    return repr(node)


def dotprint(ast, view=False, short=False, **kwargs):
    """
    Returns a string which can be used to generate a DOT-graph
    :param ast: The ast which should be generated
    :param view: Boolean, if rendering of the image directly should occur.
    :param short: if True, nodes are labeled with short descriptions instead of their full ``repr``
    :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
    :return: string in DOT format
    """
    nodeToStrFunction = __shortened if short else repr
    printer = DotPrinter(nodeToStrFunction, **kwargs)
    dot = printer.doprint(ast)
    if view:
        printer.dot.render(view=view)
    return dot


if __name__ == "__main__":
    from pystencils import Field
    import sympy as sp

    imgField = Field.createGeneric('I',
                                   spatialDimensions=2,  # 2D image
                                   indexDimensions=1)    # multiple values per pixel: e.g. RGB

    w1, w2 = sp.symbols("w_1 w_2")
    # Sobel filter in x direction: the -1 column is weighted negatively and the
    # +1 column positively (antisymmetric), centre row with w2, corners with w1.
    sobelX = -w2 * imgField[-1, 0](1) - w1 * imgField[-1, -1](1) - w1 * imgField[-1, +1](1) \
             + w2 * imgField[+1, 0](1) + w1 * imgField[+1, -1](1) + w1 * imgField[+1, +1](1)

    dstField = Field.createGeneric('dst', spatialDimensions=2, indexDimensions=0)
    updateRule = sp.Eq(dstField[0, 0], sobelX)

    from pystencils.cpu import createKernel
    ast = createKernel([updateRule])
    print(dotprint(ast, short=True))