Commit 3de207f4 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

Found some bugs. Dtypes should be now everywhere, i believe it was desympied now

parent 53b20223
......@@ -342,7 +342,8 @@ class SympyAssignment(Node):
def replace(self, child, replacement):
if child == self.lhs:
self.lhs = child
replacement.parent = self
self.lhs = replacement
elif child == self.rhs:
replacement.parent = self
self.rhs = replacement
......@@ -478,6 +479,10 @@ class Pow(Expr):
class Indexed(Expr):
def __init__(self, args, base, parent=None):
super(Indexed, self).__init__(args, parent)
self.base = base
def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1])
......@@ -506,6 +511,6 @@ class Number(Node, sp.AtomicExpr):
raise set()
def __repr__(self):
return repr(self.dtype) + repr(self.value)
return repr(self.value)
from .llvm import generateLLVM
from .cbackend import generateC, generateCUDA
from .dot import dotprint
......@@ -6,9 +6,10 @@ class DotPrinter(Printer):
"""
A printer which converts ast to DOT (graph description language).
"""
def __init__(self, nodeToStrFunction, **kwargs):
def __init__(self, nodeToStrFunction, full, **kwargs):
super(DotPrinter, self).__init__()
self._nodeToStrFunction = nodeToStrFunction
self.full = full
self.dot = Digraph(**kwargs)
self.dot.quote_edge = lang.quote
......@@ -30,6 +31,21 @@ class DotPrinter(Printer):
def _print_SympyAssignment(self, assignment):
self.dot.node(self._nodeToStrFunction(assignment))
if self.full:
for node in assignment.args:
self._print(node)
for node in assignment.args:
self.dot.edge(self._nodeToStrFunction(assignment), self._nodeToStrFunction(node))
def emptyPrinter(self, expr):
if self.full:
self.dot.node(self._nodeToStrFunction(expr))
for node in expr.args:
self._print(node)
for node in expr.args:
self.dot.edge(self._nodeToStrFunction(expr), self._nodeToStrFunction(node))
else:
raise NotImplemented('Dotprinter cannot print', expr)
def doprint(self, expr):
self._print(expr)
......@@ -48,17 +64,20 @@ def __shortened(node):
return "Assignment: " + repr(node.lhs)
def dotprint(ast, view=False, short=False, **kwargs):
def dotprint(node, view=False, short=False, full=False, **kwargs):
"""
Returns a string which can be used to generate a DOT-graph
:param ast: The ast which should be generated
:param node: The ast which should be generated
:param view: Boolen, if rendering of the image directly should occur.
:param short: Uses the __shortened output
:param full: Prints the whole tree with type information
:param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
:return: string in DOT format
"""
nodeToStrFunction = __shortened if short else repr
printer = DotPrinter(nodeToStrFunction, **kwargs)
dot = printer.doprint(ast)
nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr) if full else nodeToStrFunction
printer = DotPrinter(nodeToStrFunction, full, **kwargs)
dot = printer.doprint(node)
if view:
printer.dot.render(view=view)
return dot
......@@ -80,4 +99,4 @@ if __name__ == "__main__":
from pystencils.cpu import createKernel
ast = createKernel([updateRule])
print(dotprint(ast, short=True))
\ No newline at end of file
print(dotprint(ast, short=True))
......@@ -548,7 +548,7 @@ def get_type(node):
def insert_casts(node):
"""
Inserts casts where needed
Inserts casts and dtype where needed
:param node: ast which should be traversed
:return: node
"""
......@@ -559,7 +559,7 @@ def insert_casts(node):
print(arg)
insert_casts(arg)
if isinstance(node, ast.Indexed):
pass
node.dtype = node.base.label.dtype
elif isinstance(node, ast.Expr):
print(node)
print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args])
......@@ -594,9 +594,31 @@ def desympy_ast(node):
node.replace(arg, ast.Mul(arg.args, node))
elif isinstance(arg, sp.Pow):
node.replace(arg, ast.Pow(arg.args, node))
elif isinstance(arg, sp.tensor.Indexed):
node.replace(arg, ast.Indexed(arg.args, node))
#elif isinstance(arg, )
elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed):
node.replace(arg, ast.Indexed(arg.args, arg.base, node))
elif isinstance(arg, sp.tensor.IndexedBase):
node.replace(arg, arg.label)
#elif isinstance(arg, sp.containers.Tuple):
#
else:
print('Not transforming:', arg, type(arg))
for arg in node.args:
desympy_ast(arg)
return node
def check_dtype(node):
if isinstance(node, ast.KernelFunction):
pass
elif isinstance(node, ast.Block):
pass
elif isinstance(node, ast.LoopOverCoordinate):
pass
elif isinstance(node, ast.SympyAssignment):
pass
else:
print(node)
print(node.dtype)
for arg in node.args:
check_dtype(arg)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment