diff --git a/astnodes.py b/astnodes.py index ce68fd44412d8d94f59b99e954262bf43e10ee6c..28a4d6c67c934d6ccb7bfa920bc6b35387525c6f 100644 --- a/astnodes.py +++ b/astnodes.py @@ -342,7 +342,8 @@ class SympyAssignment(Node): def replace(self, child, replacement): if child == self.lhs: - self.lhs = child + replacement.parent = self + self.lhs = replacement elif child == self.rhs: replacement.parent = self self.rhs = replacement @@ -478,6 +479,10 @@ class Pow(Expr): class Indexed(Expr): + def __init__(self, args, base, parent=None): + super(Indexed, self).__init__(args, parent) + self.base = base + def __repr__(self): return '%s[%s]' % (self.args[0], self.args[1]) @@ -506,6 +511,6 @@ class Number(Node, sp.AtomicExpr): raise set() def __repr__(self): - return repr(self.dtype) + repr(self.value) + return repr(self.value) diff --git a/backends/__init__.py b/backends/__init__.py index b4f7b6786b4061c7174fbad0bff8bf9a1deb55a6..fe72a215430a981ab1e29a05551e23770914d67c 100644 --- a/backends/__init__.py +++ b/backends/__init__.py @@ -1,2 +1,3 @@ from .llvm import generateLLVM from .cbackend import generateC, generateCUDA +from .dot import dotprint diff --git a/backends/dot.py b/backends/dot.py index b90ffc40a8434f16b8beac04495b5731d78d73c7..aac41d02de940ae80fe1ddb43a563eec8aba4e5f 100644 --- a/backends/dot.py +++ b/backends/dot.py @@ -6,9 +6,10 @@ class DotPrinter(Printer): """ A printer which converts ast to DOT (graph description language). """ - def __init__(self, nodeToStrFunction, **kwargs): + def __init__(self, nodeToStrFunction, full, **kwargs): super(DotPrinter, self).__init__() self._nodeToStrFunction = nodeToStrFunction + self.full = full self.dot = Digraph(**kwargs) self.dot.quote_edge = lang.quote @@ -30,6 +31,21 @@ class DotPrinter(Printer): def _print_SympyAssignment(self, assignment): self.dot.node(self._nodeToStrFunction(assignment)) + if self.full: + for node in assignment.args: + self._print(node) + for node in assignment.args: + self.dot.edge(self._nodeToStrFunction(assignment), self._nodeToStrFunction(node)) + + def emptyPrinter(self, expr): + if self.full: + self.dot.node(self._nodeToStrFunction(expr)) + for node in expr.args: + self._print(node) + for node in expr.args: + self.dot.edge(self._nodeToStrFunction(expr), self._nodeToStrFunction(node)) + else: + raise NotImplemented('Dotprinter cannot print', expr) def doprint(self, expr): self._print(expr) @@ -48,17 +64,20 @@ def __shortened(node): return "Assignment: " + repr(node.lhs) -def dotprint(ast, view=False, short=False, **kwargs): +def dotprint(node, view=False, short=False, full=False, **kwargs): """ Returns a string which can be used to generate a DOT-graph - :param ast: The ast which should be generated + :param node: The ast which should be generated :param view: Boolen, if rendering of the image directly should occur. + :param short: Uses the __shortened output + :param full: Prints the whole tree with type information :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph :return: string in DOT format """ nodeToStrFunction = __shortened if short else repr - printer = DotPrinter(nodeToStrFunction, **kwargs) - dot = printer.doprint(ast) + nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr) if full else nodeToStrFunction + printer = DotPrinter(nodeToStrFunction, full, **kwargs) + dot = printer.doprint(node) if view: printer.dot.render(view=view) return dot @@ -80,4 +99,4 @@ if __name__ == "__main__": from pystencils.cpu import createKernel ast = createKernel([updateRule]) - print(dotprint(ast, short=True)) \ No newline at end of file + print(dotprint(ast, short=True)) diff --git a/transformations.py b/transformations.py index efa7bd5642b93e0995f9f33a9ad48c9930f8950f..74aba9d87c0793dcfde226b113278e2a0e4f786d 100644 --- a/transformations.py +++ b/transformations.py @@ -548,7 +548,7 @@ def get_type(node): def insert_casts(node): """ - Inserts casts where needed + Inserts casts and dtype where needed :param node: ast which should be traversed :return: node """ @@ -559,7 +559,7 @@ def insert_casts(node): print(arg) insert_casts(arg) if isinstance(node, ast.Indexed): - pass + node.dtype = node.base.label.dtype elif isinstance(node, ast.Expr): print(node) print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args]) @@ -594,9 +594,31 @@ def desympy_ast(node): node.replace(arg, ast.Mul(arg.args, node)) elif isinstance(arg, sp.Pow): node.replace(arg, ast.Pow(arg.args, node)) - elif isinstance(arg, sp.tensor.Indexed): - node.replace(arg, ast.Indexed(arg.args, node)) - #elif isinstance(arg, ) + elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed): + node.replace(arg, ast.Indexed(arg.args, arg.base, node)) + elif isinstance(arg, sp.tensor.IndexedBase): + node.replace(arg, arg.label) + #elif isinstance(arg, sp.containers.Tuple): + # + else: + print('Not transforming:', arg, type(arg)) for arg in node.args: desympy_ast(arg) return node + + +def check_dtype(node): + if isinstance(node, ast.KernelFunction): + pass + elif isinstance(node, ast.Block): + pass + elif isinstance(node, ast.LoopOverCoordinate): + pass + elif isinstance(node, ast.SympyAssignment): + pass + else: + print(node) + print(node.dtype) + for arg in node.args: + check_dtype(arg) +