diff --git a/backends/dot.py b/backends/dot.py index 4ed76f5110f65c2fe421559dedd67d1ec991f8e6..8797c49f07ef6e244f07deac2458c4026d29b2bf 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -15,40 +15,48 @@ class DotPrinter(Printer): self.dot.quote_edge = lang.quote def _print_KernelFunction(self, func): - self.dot.node(repr(func), style='filled', fillcolor='#a056db', label=self._nodeToStrFunction(func)) + self.dot.node(str(id(func)), style='filled', fillcolor='#a056db', label=self._nodeToStrFunction(func)) self._print(func.body) - self.dot.edge(repr(func), self._nodeToStrFunction(func.body)) + self.dot.edge(str(id(func)), str(id(func.body))) def _print_LoopOverCoordinate(self, loop): - self.dot.node(self._nodeToStrFunction(loop), style='filled', fillcolor='#3498db') + self.dot.node(str(id(loop)), style='filled', fillcolor='#3498db', label=self._nodeToStrFunction(loop)) self._print(loop.body) - self.dot.edge(self._nodeToStrFunction(loop), self._nodeToStrFunction(loop.body)) + self.dot.edge(str(id(loop)), str(id(loop.body))) def _print_Block(self, block): for node in block.args: self._print(node) - self.dot.node(self._nodeToStrFunction(block), style='filled', fillcolor='#dbc256', label=repr(block)) + self.dot.node(str(id(block)), style='filled', fillcolor='#dbc256', label=repr(block)) for node in block.args: - self.dot.edge(self._nodeToStrFunction(block), self._nodeToStrFunction(node)) + self.dot.edge(str(id(block)), str(id(node))) def _print_SympyAssignment(self, assignment): - self.dot.node(self._nodeToStrFunction(assignment), style='filled', fillcolor='#56db7f') + self.dot.node(str(id(assignment)), style='filled', fillcolor='#56db7f', label=self._nodeToStrFunction(assignment)) if self.full: for node in assignment.args: self._print(node) for node in assignment.args: - self.dot.edge(self._nodeToStrFunction(assignment), self._nodeToStrFunction(node)) + self.dot.edge(str(id(assignment)), str(id(node))) + + def _print_Conditional(self, expr): + self.dot.node(str(id(expr)), style='filled', fillcolor='#56bd7f', label=self._nodeToStrFunction(expr)) + self._print(expr.trueBlock) + self.dot.edge(str(id(expr)), str(id(expr.trueBlock))) + if expr.falseBlock: + self._print(expr.falseBlock) + self.dot.edge(str(id(expr)), str(id(expr.falseBlock))) def emptyPrinter(self, expr): if self.full: - self.dot.node(self._nodeToStrFunction(expr)) + self.dot.node(str(id(expr)), label=self._nodeToStrFunction(expr)) for node in expr.args: self._print(node) for node in expr.args: - self.dot.edge(self._nodeToStrFunction(expr), self._nodeToStrFunction(node)) + self.dot.edge(str(id(expr)), str(id(node))) else: - raise NotImplemented('Dotprinter cannot print', expr) + raise NotImplementedError('Dotprinter cannot print', type(expr), expr) def doprint(self, expr): self._print(expr) @@ -56,18 +64,19 @@ class DotPrinter(Printer): def __shortened(node): - from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Block + from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Block, Conditional if isinstance(node, LoopOverCoordinate): return "Loop over dim %d" % (node.coordinateToLoopOver,) elif isinstance(node, KernelFunction): params = [f.name for f in node.fieldsAccessed] params += [p.name for p in node.parameters if not p.isFieldArgument] - print(params) return "Func: %s (%s)" % (node.functionName, ",".join(params)) elif isinstance(node, SympyAssignment): return repr(node.lhs) elif isinstance(node, Block): return "Block" + str(id(node)) + elif isinstance(node, Conditional): + return repr(node) else: raise NotImplementedError("Cannot handle node type %s" % (type(node),))