Commit 9a21e0c7 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

Transforming ast to nonsympy ast

parent a211a120
...@@ -96,7 +96,7 @@ class KernelFunction(Node): ...@@ -96,7 +96,7 @@ class KernelFunction(Node):
@property @property
def args(self): def args(self):
yield self._body return [self._body]
@property @property
def fieldsAccessed(self): def fieldsAccessed(self):
...@@ -286,7 +286,6 @@ class LoopOverCoordinate(Node): ...@@ -286,7 +286,6 @@ class LoopOverCoordinate(Node):
class SympyAssignment(Node): class SympyAssignment(Node):
def __init__(self, lhsSymbol, rhsTerm, isConst=True): def __init__(self, lhsSymbol, rhsTerm, isConst=True):
self._lhsSymbol = lhsSymbol self._lhsSymbol = lhsSymbol
self.rhs = rhsTerm self.rhs = rhsTerm
...@@ -337,6 +336,15 @@ class SympyAssignment(Node): ...@@ -337,6 +336,15 @@ class SympyAssignment(Node):
def isConst(self): def isConst(self):
return self._isConst return self._isConst
def replace(self, child, replacement):
if child == self.lhs:
self.lhs = child
elif child == self.rhs:
replacement.parent = self
self.rhs = replacement
raise ValueError('%s is not in args of %s' % (replacement, self.__class__))
def __repr__(self): def __repr__(self):
return repr(self.lhs) + " = " + repr(self.rhs) return repr(self.lhs) + " = " + repr(self.rhs)
...@@ -378,3 +386,55 @@ class TemporaryMemoryFree(Node): ...@@ -378,3 +386,55 @@ class TemporaryMemoryFree(Node):
def args(self): def args(self):
return [] return []
# TODO everything which is not Atomic expression: Pow)
class Expr(Node):
def __init__(self, args, parent=None):
super(Expr, self).__init__(parent)
self._args = list(args)
def args(self):
return self._args
def args(self, value):
self._args = value
def replace(self, child, replacements):
idx = self.args.index(child)
del self.args[idx]
if type(replacements) is list:
for e in replacements:
e.parent = self
self.args = self.args[:idx] + replacements + self.args[idx:]
replacements.parent = self
self.args.insert(idx, replacements)
def symbolsDefined(self):
return set() # Todo fix for symbol analysis
def undefinedSymbols(self):
return set() # Todo fix for symbol analysis
class Mul(Expr):
class Add(Expr):
class Pow(Expr):
class Indexed(Expr):
...@@ -517,3 +517,32 @@ def getLoopHierarchy(astNode): ...@@ -517,3 +517,32 @@ def getLoopHierarchy(astNode):
result.append(node.coordinateToLoopOver) result.append(node.coordinateToLoopOver)
return reversed(result) return reversed(result)
def insert_casts(node):
if isinstance(node, ast.SympyAssignment):
elif isinstance(node, sp.Expr):
for arg in node.args:
def desympy_ast(node):
# if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase):
# print(node, type(node))
for i in range(len(node.args)):
arg = node.args[i]
if isinstance(node, ast.SympyAssignment):
print(node, type(arg))
if isinstance(arg, sp.Add):
node.replace(arg, ast.Add(arg.args, node))
elif isinstance(arg, sp.Mul):
node.replace(arg, ast.Mul(arg.args, node))
elif isinstance(arg, sp.Pow):
node.replace(arg, ast.Pow(arg.args, node))
elif isinstance(arg, sp.tensor.Indexed):
node.replace(arg, ast.Indexed(arg.args, node))
for arg in node.args:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment