diff --git a/ast.py b/ast.py
index 60be3db9b2a543938c6b4f3d0ec88035a7bbcc4a..b0b91b80aad0d6aae169c29fcc7440ea55fe8bcd 100644
--- a/ast.py
+++ b/ast.py
@@ -96,7 +96,7 @@ class KernelFunction(Node):
 
     @property
     def args(self):
-        yield self._body
+        return [self._body]
 
     @property
     def fieldsAccessed(self):
@@ -286,7 +286,6 @@ class LoopOverCoordinate(Node):
 
 
 class SympyAssignment(Node):
-
     def __init__(self, lhsSymbol, rhsTerm, isConst=True):
         self._lhsSymbol = lhsSymbol
         self.rhs = rhsTerm
@@ -337,6 +336,15 @@ class SympyAssignment(Node):
     def isConst(self):
         return self._isConst
 
+    def replace(self, child, replacement):
+        if child == self.lhs:
+            self.lhs = child
+        elif child == self.rhs:
+            replacement.parent = self
+            self.rhs = replacement
+        else:
+            raise ValueError('%s is not in args of %s' % (replacement, self.__class__))
+
     def __repr__(self):
         return repr(self.lhs) + " = " + repr(self.rhs)
 
@@ -378,3 +386,55 @@ class TemporaryMemoryFree(Node):
     def args(self):
         return []
 
+
+# TODO everything which is not Atomic expression: Pow)
+
+class Expr(Node):
+    def __init__(self, args, parent=None):
+        super(Expr, self).__init__(parent)
+        self._args = list(args)
+
+    @property
+    def args(self):
+        return self._args
+
+    @args.setter
+    def args(self, value):
+        self._args = value
+
+    def replace(self, child, replacements):
+        idx = self.args.index(child)
+        del self.args[idx]
+        if type(replacements) is list:
+            for e in replacements:
+                e.parent = self
+            self.args = self.args[:idx] + replacements + self.args[idx:]
+        else:
+            replacements.parent = self
+            self.args.insert(idx, replacements)
+
+    @property
+    def symbolsDefined(self):
+        return set()  # Todo fix for symbol analysis
+
+    @property
+    def undefinedSymbols(self):
+        return set()  # Todo fix for symbol analysis
+
+
+class Mul(Expr):
+    pass
+
+
+class Add(Expr):
+    pass
+
+
+class Pow(Expr):
+    pass
+
+
+class Indexed(Expr):
+    pass
+
+
diff --git a/transformations.py b/transformations.py
index aea9eb56b5062163426c8139d545c33529ae50fb..816e46a089f6827c5ae3b2d108ddca05b8fb8a54 100644
--- a/transformations.py
+++ b/transformations.py
@@ -517,3 +517,32 @@ def getLoopHierarchy(astNode):
             result.append(node.coordinateToLoopOver)
     return reversed(result)
 
+
+def insert_casts(node):
+    if isinstance(node, ast.SympyAssignment):
+        pass
+    elif isinstance(node, sp.Expr):
+        pass
+    else:
+        for arg in node.args:
+            insert_casts(arg)
+
+
+def desympy_ast(node):
+    # if isinstance(node, sp.Expr) and not isinstance(node, sp.AtomicExpr) and not isinstance(node, sp.tensor.IndexedBase):
+    #    print(node, type(node))
+
+    for i in range(len(node.args)):
+        arg = node.args[i]
+        if isinstance(node, ast.SympyAssignment):
+            print(node, type(arg))
+        if isinstance(arg, sp.Add):
+            node.replace(arg, ast.Add(arg.args, node))
+        elif isinstance(arg, sp.Mul):
+            node.replace(arg, ast.Mul(arg.args, node))
+        elif isinstance(arg, sp.Pow):
+            node.replace(arg, ast.Pow(arg.args, node))
+        elif isinstance(arg, sp.tensor.Indexed):
+            node.replace(arg, ast.Indexed(arg.args, node))
+    for arg in node.args:
+        desympy_ast(arg)