From 9a21e0c7f142d4a57d96aa14432ca117d13e24ac Mon Sep 17 00:00:00 2001 From: Jan Hoenig <hrominium@gmail.com> Date: Fri, 16 Dec 2016 18:24:22 +0100 Subject: [PATCH] Transforming ast to nonsympy ast --- ast.py | 64 ++++++++++++++++++++++++++++++++++++++++++++-- transformations.py | 29 +++++++++++++++++++++ 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/ast.py b/ast.py index 60be3db9b..b0b91b80a 100644 --- a/ast.py +++ b/ast.py @@ -96,7 +96,7 @@ class KernelFunction(Node): @property def args(self): - yield self._body + return [self._body] @property def fieldsAccessed(self): @@ -286,7 +286,6 @@ class LoopOverCoordinate(Node): class SympyAssignment(Node): - def __init__(self, lhsSymbol, rhsTerm, isConst=True): self._lhsSymbol = lhsSymbol self.rhs = rhsTerm @@ -337,6 +336,15 @@ class SympyAssignment(Node): def isConst(self): return self._isConst + def replace(self, child, replacement): + if child == self.lhs: + self.lhs = child + elif child == self.rhs: + replacement.parent = self + self.rhs = replacement + else: + raise ValueError('%s is not in args of %s' % (replacement, self.__class__)) + def __repr__(self): return repr(self.lhs) + " = " + repr(self.rhs) @@ -378,3 +386,55 @@ class TemporaryMemoryFree(Node): def args(self): return [] + +# TODO everything which is not Atomic expression: Pow) + +class Expr(Node): + def __init__(self, args, parent=None): + super(Expr, self).__init__(parent) + self._args = list(args) + + @property + def args(self): + return self._args + + @args.setter + def args(self, value): + self._args = value + + def replace(self, child, replacements): + idx = self.args.index(child) + del self.args[idx] + if type(replacements) is list: + for e in replacements: + e.parent = self + self.args = self.args[:idx] + replacements + self.args[idx:] + else: + replacements.parent = self + self.args.insert(idx, replacements) + + @property + def symbolsDefined(self): + return set() # Todo fix for symbol analysis + + @property + def undefinedSymbols(self): + return set() # Todo fix for symbol analysis + + +class Mul(Expr): + pass + + +class Add(Expr): + pass + + +class Pow(Expr): + pass + + +class Indexed(Expr): + pass + + diff --git a/transformations.py b/transformations.py index aea9eb56b..816e46a08 100644 --- a/transformations.py +++ b/transformations.py @@ -517,3 +517,32 @@ def getLoopHierarchy(astNode): result.append(node.coordinateToLoopOver) return reversed(result) + +def insert_casts(node): + if isinstance(node, ast.SympyAssignment): + pass + elif isinstance(node, sp.Expr): + pass + else: + for arg in node.args: + insert_casts(arg) + + +def desympy_ast(node): + # if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase): + # print(node, type(node)) + + for i in range(len(node.args)): + arg = node.args[i] + if isinstance(node, ast.SympyAssignment): + print(node, type(arg)) + if isinstance(arg, sp.Add): + node.replace(arg, ast.Add(arg.args, node)) + elif isinstance(arg, sp.Mul): + node.replace(arg, ast.Mul(arg.args, node)) + elif isinstance(arg, sp.Pow): + node.replace(arg, ast.Pow(arg.args, node)) + elif isinstance(arg, sp.tensor.Indexed): + node.replace(arg, ast.Indexed(arg.args, node)) + for arg in node.args: + desympy_ast(arg) -- GitLab