From bd56d31fc1e9c9f526e6f122b0ca06140866dfd6 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 2 Nov 2016 16:30:41 +0100
Subject: [PATCH] Major refactoring: separated Ast and generateC code

---
 ast.py               | 320 ++++++++++++++++
 backends/__init__.py |   0
 backends/cbackend.py | 164 ++++++++
 field.py             |   1 +
 finitedifferences.py |   2 +-
 generator.py         | 893 -------------------------------------------
 jit.py               |   4 +-
 transformations.py   | 444 +++++++++++++++++++++
 8 files changed, 932 insertions(+), 896 deletions(-)
 create mode 100644 ast.py
 create mode 100644 backends/__init__.py
 create mode 100644 backends/cbackend.py
 delete mode 100644 generator.py
 create mode 100644 transformations.py

diff --git a/ast.py b/ast.py
new file mode 100644
index 000000000..5eed391d0
--- /dev/null
+++ b/ast.py
@@ -0,0 +1,320 @@
+import sympy as sp
+from sympy.tensor import IndexedBase, Indexed
+from pystencils.field import Field
+from pystencils.typedsymbol import TypedSymbol
+
+
+class Node:
+    def __init__(self, parent=None):
+        self.parent = parent
+
+    def args(self):
+        return []
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    def atoms(self, argType):
+        result = set()
+        for arg in self.args:
+            if isinstance(arg, argType):
+                result.add(arg)
+            result.update(arg.atoms(argType))
+        return result
+
+
+class KernelFunction(Node):
+
+    class Argument:
+        def __init__(self, name, dtype):
+            self.name = name
+            self.dtype = dtype
+            self.isFieldPtrArgument = False
+            self.isFieldShapeArgument = False
+            self.isFieldStrideArgument = False
+            self.isFieldArgument = False
+            self.fieldName = ""
+            self.coordinate = None
+
+            if name.startswith(Field.DATA_PREFIX):
+                self.isFieldPtrArgument = True
+                self.isFieldArgument = True
+                self.fieldName = name[len(Field.DATA_PREFIX):]
+            elif name.startswith(Field.SHAPE_PREFIX):
+                self.isFieldShapeArgument = True
+                self.isFieldArgument = True
+                self.fieldName = name[len(Field.SHAPE_PREFIX):]
+            elif name.startswith(Field.STRIDE_PREFIX):
+                self.isFieldStrideArgument = True
+                self.isFieldArgument = True
+                self.fieldName = name[len(Field.STRIDE_PREFIX):]
+
+    def __init__(self, body, functionName="kernel"):
+        super(KernelFunction, self).__init__()
+        self._body = body
+        self._parameters = None
+        self._functionName = functionName
+        self._body.parent = self
+        self.variablesToIgnore = set()
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    @property
+    def parameters(self):
+        self._updateParameters()
+        return self._parameters
+
+    @property
+    def body(self):
+        return self._body
+
+    @property
+    def args(self):
+        return [self._body]
+
+    @property
+    def functionName(self):
+        return self._functionName
+
+    def _updateParameters(self):
+        undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore
+        self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
+        self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
+                                             l.isFieldStrideArgument, l.name),
+                              reverse=True)
+
+
+class Block(Node):
+    def __init__(self, listOfNodes):
+        super(Node, self).__init__()
+        self._nodes = listOfNodes
+        for n in self._nodes:
+            n.parent = self
+
+    @property
+    def args(self):
+        return self._nodes
+
+    def insertFront(self, node):
+        node.parent = self
+        self._nodes.insert(0, node)
+
+    def append(self, node):
+        node.parent = self
+        self._nodes.append(node)
+
+    def takeChildNodes(self):
+        tmp = self._nodes
+        self._nodes = []
+        return tmp
+
+    def replace(self, child, replacements):
+        idx = self._nodes.index(child)
+        del self._nodes[idx]
+        if type(replacements) is list:
+            for e in replacements:
+                e.parent = self
+            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
+        else:
+            replacements.parent = self
+            self._nodes.insert(idx, replacements)
+
+    @property
+    def symbolsDefined(self):
+        result = set()
+        for a in self.args:
+            result.update(a.symbolsDefined)
+        return result
+
+    @property
+    def symbolsRead(self):
+        result = set()
+        for a in self.args:
+            result.update(a.symbolsRead)
+        return result
+
+
+class PragmaBlock(Block):
+    def __init__(self, pragmaLine, listOfNodes):
+        super(PragmaBlock, self).__init__(listOfNodes)
+        self.pragmaLine = pragmaLine
+
+
+class LoopOverCoordinate(Node):
+    LOOP_COUNTER_NAME_PREFIX = "ctr"
+
+    def __init__(self, body, coordinateToLoopOver, shape, increment=1, ghostLayers=1,
+                 isInnermostLoop=False, isOutermostLoop=False):
+        self._body = body
+        self._coordinateToLoopOver = coordinateToLoopOver
+        self._shape = shape
+        self._increment = increment
+        self._ghostLayers = ghostLayers
+        self._body.parent = self
+        self.prefixLines = []
+        self._isInnermostLoop = isInnermostLoop
+        self._isOutermostLoop = isOutermostLoop
+
+    def newLoopWithDifferentBody(self, newBody):
+        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._shape, self._increment,
+                                    self._ghostLayers, self._isInnermostLoop, self._isOutermostLoop)
+        result.prefixLines = self.prefixLines
+        return result
+
+    @property
+    def args(self):
+        result = [self._body]
+        limit = self._shape[self._coordinateToLoopOver]
+        if isinstance(limit, Node) or isinstance(limit, sp.Basic):
+            result.append(limit)
+        return result
+
+    @property
+    def body(self):
+        return self._body
+
+    @property
+    def iterationEnd(self):
+        return self._shape[self.coordinateToLoopOver] - self.ghostLayers
+
+    @property
+    def coordinateToLoopOver(self):
+        return self._coordinateToLoopOver
+
+    @property
+    def symbolsDefined(self):
+        result = self._body.symbolsDefined
+        result.add(self.loopCounterSymbol)
+        return result
+
+    @property
+    def loopCounterName(self):
+        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, self._coordinateToLoopOver)
+
+    @property
+    def loopCounterSymbol(self):
+        return TypedSymbol(self.loopCounterName, "int")
+
+    @property
+    def symbolsRead(self):
+        result = self._body.symbolsRead
+        limit = self._shape[self._coordinateToLoopOver]
+        if isinstance(limit, sp.Basic):
+            result.update(limit.atoms(sp.Symbol))
+        return result
+
+    @property
+    def isOutermostLoop(self):
+        return self._isOutermostLoop
+
+    @property
+    def isInnermostLoop(self):
+        return self._isInnermostLoop
+
+    @property
+    def coordinateToLoopOver(self):
+        return self._coordinateToLoopOver
+
+    @property
+    def iterationRegionWithGhostLayer(self):
+        return self._shape[self._coordinateToLoopOver]
+
+    @property
+    def ghostLayers(self):
+        return self._ghostLayers
+
+
+class SympyAssignment(Node):
+
+    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
+        self._lhsSymbol = lhsSymbol
+        self.rhs = rhsTerm
+        self._isDeclaration = True
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase):
+            self._isDeclaration = False
+        self._isConst = isConst
+
+    @property
+    def lhs(self):
+        return self._lhsSymbol
+
+    @lhs.setter
+    def lhs(self, newValue):
+        self._lhsSymbol = newValue
+        self._isDeclaration = True
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, Indexed):
+            self._isDeclaration = False
+
+    @property
+    def args(self):
+        return [self._lhsSymbol, self.rhs]
+
+    @property
+    def symbolsDefined(self):
+        if not self._isDeclaration:
+            return set()
+        return set([self._lhsSymbol])
+
+    @property
+    def symbolsRead(self):
+        result = self.rhs.atoms(sp.Symbol)
+        result.update(self._lhsSymbol.atoms(sp.Symbol))
+        return result
+
+    @property
+    def isDeclaration(self):
+        return self._isDeclaration
+
+    @property
+    def isConst(self):
+        return self._isConst
+
+    def __repr__(self):
+        return repr(self.lhs) + " = " + repr(self.rhs)
+
+
+class TemporaryMemoryAllocation(Node):
+    def __init__(self, typedSymbol, size):
+        self.symbol = typedSymbol
+        self.size = size
+
+    @property
+    def symbolsDefined(self):
+        return set([self._symbol])
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    @property
+    def args(self):
+        return [self._symbol]
+
+
+class TemporaryMemoryFree(Node):
+    def __init__(self, typedSymbol):
+        self._symbol = typedSymbol
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    @property
+    def args(self):
+        return []
+
diff --git a/backends/__init__.py b/backends/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backends/cbackend.py b/backends/cbackend.py
new file mode 100644
index 000000000..70d1eb2c0
--- /dev/null
+++ b/backends/cbackend.py
@@ -0,0 +1,164 @@
+import cgen as c
+from sympy.utilities.codegen import CCodePrinter
+from pystencils.ast import Node
+
+
+def printCCode(astNode):
+    printer = CBackend(cuda=False)
+    return printer(astNode)
+
+
+def printCudaCode(astNode):
+    printer = CBackend(cuda=True)
+    return printer(astNode)
+
+# --------------------------------------- Backend Specific Nodes -------------------------------------------------------
+
+
+class CustomCppCode(Node):
+    def __init__(self, code, symbolsRead, symbolsDefined):
+        self._code = "\n" + code
+        self._symbolsRead = set(symbolsRead)
+        self._symbolsDefined = set(symbolsDefined)
+
+    @property
+    def code(self):
+        return self._code
+
+    @property
+    def args(self):
+        return []
+
+    @property
+    def symbolsDefined(self):
+        return self._symbolsDefined
+
+    @property
+    def symbolsRead(self):
+        return self._symbolsRead
+
+    def generateC(self):
+        return c.LiteralLines(self._code)
+
+
+class PrintNode(CustomCppCode):
+    def __init__(self, symbolToPrint):
+        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name)
+        super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set())
+
+
+# ------------------------------------------- Printer ------------------------------------------------------------------
+
+
+class CBackend:
+
+    def __init__(self, cuda=False):
+        self.cuda = cuda
+        self.sympyPrinter = CustomSympyPrinter()
+
+    def __call__(self, node):
+        return self._print(node)
+
+    def _print(self, node):
+        for cls in type(node).__mro__:
+            methodName = "_print_" + cls.__name__
+            if hasattr(self, methodName):
+                return getattr(self, methodName)(node)
+        raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
+
+    def _print_KernelFunction(self, node):
+        functionArguments = [MyPOD(s.dtype, s.name) for s in node.parameters]
+        prefix = "__global__ void" if self.cuda else "void"
+        functionPOD = MyPOD(prefix, node.functionName, )
+        funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments)
+        return c.FunctionBody(funcDeclaration, self._print(node.body))
+
+    def _print_Block(self, node):
+        return c.Block([self._print(child) for child in node.args])
+
+    def _print_PragmaBlock(self, node):
+        class PragmaGenerable(c.Generable):
+            def __init__(self, line, block):
+                self._line = line
+                self._block = block
+
+            def generate(self):
+                yield self._line
+                for e in self._block.generate():
+                    yield e
+        return PragmaGenerable(node.pragmaLine, self._print_Block(node))
+
+    def _print_LoopOverCoordinate(self, node):
+        class LoopWithOptionalPrefix(c.CustomLoop):
+            def __init__(self, intro_line, body, prefixLines=[]):
+                super(LoopWithOptionalPrefix, self).__init__(intro_line, body)
+                self.prefixLines = prefixLines
+
+            def generate(self):
+                for l in self.prefixLines:
+                    yield l
+
+                for e in super(LoopWithOptionalPrefix, self).generate():
+                    yield e
+
+        counterVar = node.loopCounterName
+        start = "int %s = %d" % (counterVar, node.ghostLayers)
+        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.iterationEnd))
+        update = "++%s" % (counterVar,)
+        loopStr = "for (%s; %s; %s)" % (start, condition, update)
+        return LoopWithOptionalPrefix(loopStr, self._print(node.body), prefixLines=node.prefixLines)
+
+    def _print_SympyAssignment(self, node):
+        dtype = ""
+        if node.isDeclaration:
+            if node.isConst:
+                dtype = "const " + node.lhs.dtype + " "
+            else:
+                dtype = node.lhs.dtype + " "
+
+        return c.Assign(dtype + self.sympyPrinter.doprint(node.lhs),
+                        self.sympyPrinter.doprint(node.rhs))
+
+    def _print_TemporaryMemoryAllocation(self, node):
+        return c.Assign(node.symbol.dtype + " * " + self.sympyPrinter.doprint(node.symbol),
+                        "new %s[%s]" % (node.symbol.dtype, self.sympyPrinter.doprint(node.size)))
+
+    def _print_TemporaryMemoryFree(self, node):
+        return c.Statement("delete [] %s" % (self.sympyPrinter.doprint(node.symbol),))
+
+    def _print_CustomCppCode(self, node):
+        return c.LiteralLines(node.code)
+
+
+# ------------------------------------------ Helper function & classes -------------------------------------------------
+
+
+class CustomSympyPrinter(CCodePrinter):
+    def _print_Pow(self, expr):
+        """Don't use std::pow function, for small integer exponents, write as multiplication"""
+        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
+            return '(' + '*'.join(["(" + self._print(expr.base) + ")"] * expr.exp) + ')'
+        else:
+            return super(CustomSympyPrinter, self)._print_Pow(expr)
+
+    def _print_Rational(self, expr):
+        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
+        return str(expr.evalf().num)
+
+    def _print_Equality(self, expr):
+        """Equality operator is not printable in default printer"""
+        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
+
+    def _print_Piecewise(self, expr):
+        """Print piecewise in one line (remove newlines)"""
+        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
+        return result.replace("\n", "")
+
+
+class MyPOD(c.Declarator):
+    def __init__(self, dtype, name):
+        self.dtype = dtype
+        self.name = name
+
+    def get_decl_pair(self):
+        return [self.dtype], self.name
diff --git a/field.py b/field.py
index 5067bc9de..2fe288e58 100644
--- a/field.py
+++ b/field.py
@@ -300,6 +300,7 @@ class Field:
     SHAPE_PREFIX = PREFIX + "shape_"
     STRIDE_DTYPE = "const int *"
     SHAPE_DTYPE = "const int *"
+    DATA_PREFIX = PREFIX + "d_"
 
     class Access(sp.Symbol):
         def __new__(cls, name, *args, **kwargs):
diff --git a/finitedifferences.py b/finitedifferences.py
index 80357943d..b31bc5646 100644
--- a/finitedifferences.py
+++ b/finitedifferences.py
@@ -1,7 +1,7 @@
 import numpy as np
 import sympy as sp
 
-from pystencils.generator import Field
+from pystencils.field import Field
 
 
 def __upDownOffsets(d, dim):
diff --git a/generator.py b/generator.py
deleted file mode 100644
index 5a8dcec4c..000000000
--- a/generator.py
+++ /dev/null
@@ -1,893 +0,0 @@
-from collections import defaultdict
-import cgen as c
-import sympy as sp
-from sympy.logic.boolalg import Boolean
-from sympy.utilities.codegen import CCodePrinter
-from sympy.tensor import IndexedBase, Indexed
-from pystencils.field import Field, offsetComponentToDirectionString
-from pystencils.typedsymbol import TypedSymbol
-
-COORDINATE_LOOP_COUNTER_NAME = "ctr"
-FIELD_PTR_PREFIX = Field.PREFIX + "d_"
-
-
-# --------------------------------------- Helper Functions -------------------------------------------------------------
-
-
-class CodePrinter(CCodePrinter):
-    def _print_Pow(self, expr):
-        if expr.exp.is_integer and expr.exp.is_number and expr.exp > 0:
-            return '(' + '*'.join(["(" + self._print(expr.base) + ")"] * expr.exp) + ')'
-        else:
-            return super(CodePrinter, self)._print_Pow(expr)
-
-    def _print_Rational(self, expr):
-        return str(expr.evalf().num)
-
-    def _print_Equality(self, expr):
-        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
-
-    def _print_Piecewise(self, expr):
-        result = super(CodePrinter, self)._print_Piecewise(expr)
-        return result.replace("\n", "")
-
-codePrinter = CodePrinter()
-
-
-class MyPOD(c.Declarator):
-    def __init__(self, dtype, name):
-        self.dtype = dtype
-        self.name = name
-
-    def get_decl_pair(self):
-        return [self.dtype], self.name
-
-
-def getNextParentOfType(node, parentType):
-    parent = node.parent
-    while parent is not None:
-        if isinstance(parent, parentType):
-            return parent
-        parent = parent.parent
-    return None
-
-
-# --------------------------------------- AST Nodes  -------------------------------------------------------------------
-
-
-class Node:
-    def __init__(self, parent=None):
-        self.parent = parent
-
-    def args(self):
-        return []
-
-    def atoms(self, argType):
-        result = set()
-        for arg in self.args:
-            if isinstance(arg, argType):
-                result.add(arg)
-            result.update(arg.atoms(argType))
-        return result
-
-
-class DebugNode(Node):
-    def __init__(self, code, symbolsRead=[]):
-        self._code = code
-        self._symbolsRead = set(symbolsRead)
-
-    @property
-    def args(self):
-        return []
-
-    @property
-    def symbolsDefined(self):
-        return set()
-
-    @property
-    def symbolsRead(self):
-        return self._symbolsRead
-
-    def generateC(self):
-        return c.LiteralLines(self._code)
-
-
-class PrintNode(DebugNode):
-    def __init__(self, symbolToPrint):
-        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name)
-        super(PrintNode, self).__init__(code, [symbolToPrint])
-
-
-class KernelFunction(Node):
-
-    class Argument:
-        def __init__(self, name, dtype):
-            self.name = name
-            self.dtype = dtype
-            self.isFieldPtrArgument = False
-            self.isFieldShapeArgument = False
-            self.isFieldStrideArgument = False
-            self.isFieldArgument = False
-            self.fieldName = ""
-            self.coordinate = None
-
-            if name.startswith(FIELD_PTR_PREFIX):
-                self.isFieldPtrArgument = True
-                self.isFieldArgument = True
-                self.fieldName = name[len(FIELD_PTR_PREFIX):]
-            elif name.startswith(Field.SHAPE_PREFIX):
-                self.isFieldShapeArgument = True
-                self.isFieldArgument = True
-                self.fieldName = name[len(Field.SHAPE_PREFIX):]
-            elif name.startswith(Field.STRIDE_PREFIX):
-                self.isFieldStrideArgument = True
-                self.isFieldArgument = True
-                self.fieldName = name[len(Field.STRIDE_PREFIX):]
-
-    def __init__(self, body, functionName="kernel"):
-        super(KernelFunction, self).__init__()
-        self._body = body
-        self._parameters = None
-        self._functionName = functionName
-        self._body.parent = self
-        self.variablesToIgnore = set()
-        self.qualifierPrefix = ""
-
-    @property
-    def symbolsDefined(self):
-        return set()
-
-    @property
-    def symbolsRead(self):
-        return set()
-
-    @property
-    def parameters(self):
-        self._updateArguments()
-        return self._parameters
-
-    @property
-    def body(self):
-        return self._body
-
-    @property
-    def args(self):
-        return [self._body]
-
-    @property
-    def functionName(self):
-        return self._functionName
-
-    def _updateArguments(self):
-        undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore
-        self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
-        self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
-                                             l.isFieldStrideArgument, l.name),
-                              reverse=True)
-
-    def generateC(self):
-        self._updateArguments()
-        functionArguments = [MyPOD(s.dtype, s.name) for s in self._parameters]
-        functionPOD = MyPOD(self.qualifierPrefix + "void", self._functionName, )
-        funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments)
-        return c.FunctionBody(funcDeclaration, self._body.generateC())
-
-
-class Block(Node):
-    def __init__(self, listOfNodes):
-        super(Node, self).__init__()
-        self._nodes = listOfNodes
-        for n in self._nodes:
-            n.parent = self
-
-    @property
-    def args(self):
-        return self._nodes
-
-    def insertFront(self, node):
-        node.parent = self
-        self._nodes.insert(0, node)
-
-    def append(self, node):
-        node.parent = self
-        self._nodes.append(node)
-
-    def generateC(self):
-        return c.Block([e.generateC() for e in self.args])
-
-    def takeChildNodes(self):
-        tmp = self._nodes
-        self._nodes = []
-        return tmp
-
-    def replace(self, child, replacements):
-        idx = self._nodes.index(child)
-        del self._nodes[idx]
-        if type(replacements) is list:
-            for e in replacements:
-                e.parent = self
-            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
-        else:
-            replacements.parent = self
-            self._nodes.insert(idx, replacements)
-
-    @property
-    def symbolsDefined(self):
-        result = set()
-        for a in self.args:
-            result.update(a.symbolsDefined)
-        return result
-
-    @property
-    def symbolsRead(self):
-        result = set()
-        for a in self.args:
-            result.update(a.symbolsRead)
-        return result
-
-
-class PragmaBlock(Block):
-    def __init__(self, pragmaLine, listOfNodes):
-        super(PragmaBlock, self).__init__(listOfNodes)
-        self._pragmaLine = pragmaLine
-
-    def generateC(self):
-        class PragmaGenerable(c.Generable):
-            def __init__(self, line, block):
-                self._line = line
-                self._block = block
-
-            def generate(self):
-                yield self._line
-                for e in self._block.generate():
-                    yield e
-
-        return PragmaGenerable(self._pragmaLine, super(PragmaBlock, self).generateC())
-
-
-class LoopOverCoordinate(Node):
-
-    def __init__(self, body, coordinateToLoopOver, shape, increment=1, ghostLayers=1,
-                 isInnermostLoop=False, isOutermostLoop=False):
-        self._body = body
-        self._coordinateToLoopOver = coordinateToLoopOver
-        self._shape = shape
-        self._increment = increment
-        self._ghostLayers = ghostLayers
-        self._body.parent = self
-        self.prefixLines = []
-        self._isInnermostLoop = isInnermostLoop
-        self._isOutermostLoop = isOutermostLoop
-
-    def newLoopWithDifferentBody(self, newBody):
-        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._shape, self._increment,
-                                    self._ghostLayers, self._isInnermostLoop, self._isOutermostLoop)
-        result.prefixLines = self.prefixLines
-        return result
-
-    @property
-    def args(self):
-        result = [self._body]
-        limit = self._shape[self._coordinateToLoopOver]
-        if isinstance(limit, Node) or isinstance(limit, sp.Basic):
-            result.append(limit)
-        return result
-
-    @property
-    def body(self):
-        return self._body
-
-    @property
-    def loopCounterName(self):
-        return "%s_%s" % (COORDINATE_LOOP_COUNTER_NAME, self._coordinateToLoopOver)
-
-    @property
-    def coordinateToLoopOver(self):
-        return self._coordinateToLoopOver
-
-    @property
-    def symbolsDefined(self):
-        result = self._body.symbolsDefined
-        result.add(self.loopCounterSymbol)
-        return result
-
-    @property
-    def loopCounterSymbol(self):
-        return TypedSymbol(self.loopCounterName, "int")
-
-    @property
-    def symbolsRead(self):
-        result = self._body.symbolsRead
-        limit = self._shape[self._coordinateToLoopOver]
-        if isinstance(limit, sp.Basic):
-            result.update(limit.atoms(sp.Symbol))
-        return result
-
-    @property
-    def isOutermostLoop(self):
-        return self._isOutermostLoop
-
-    @property
-    def isInnermostLoop(self):
-        return self._isInnermostLoop
-
-    @property
-    def coordinateToLoopOver(self):
-        return self._coordinateToLoopOver
-
-    @property
-    def iterationRegionWithGhostLayer(self):
-        return self._shape[self._coordinateToLoopOver]
-
-    def generateC(self):
-        coord = self._coordinateToLoopOver
-        end = self._shape[coord] - self._ghostLayers
-
-        counterVar = self.loopCounterName
-
-        class LoopWithOptionalPrefix(c.CustomLoop):
-            def __init__(self, intro_line, body, prefixLines=[]):
-                super(LoopWithOptionalPrefix, self).__init__(intro_line, body)
-                self.prefixLines = prefixLines
-
-            def generate(self):
-                for l in self.prefixLines:
-                    yield l
-
-                for e in super(LoopWithOptionalPrefix, self).generate():
-                    yield e
-
-        start = "int %s = %d" % (counterVar, self._ghostLayers)
-        condition = "%s < %s" % (counterVar, codePrinter.doprint(end))
-        update = "++%s" % (counterVar,)
-        loopStr = "for (%s; %s; %s)" % (start, condition, update)
-        return LoopWithOptionalPrefix(loopStr, self._body.generateC(), prefixLines=self.prefixLines)
-
-
-class SympyAssignment(Node):
-
-    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
-        self._lhsSymbol = lhsSymbol
-        self.rhs = rhsTerm
-        self._isDeclaration = True
-        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase):
-            self._isDeclaration = False
-        self._isConst = isConst
-
-    @property
-    def lhs(self):
-        return self._lhsSymbol
-
-    @lhs.setter
-    def lhs(self, newValue):
-        self._lhsSymbol = newValue
-        self._isDeclaration = True
-        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, Indexed):
-            self._isDeclaration = False
-
-    @property
-    def args(self):
-        return [self._lhsSymbol, self.rhs]
-
-    @property
-    def symbolsDefined(self):
-        if not self._isDeclaration:
-            return set()
-        return set([self._lhsSymbol])
-
-    @property
-    def symbolsRead(self):
-        result = self.rhs.atoms(sp.Symbol)
-        result.update(self._lhsSymbol.atoms(sp.Symbol))
-        return result
-
-    @property
-    def isConst(self):
-        return self._isConst
-
-    def __repr__(self):
-        return repr(self.lhs) + " = " + repr(self.rhs)
-
-    def generateC(self):
-        dtype = ""
-        if hasattr(self._lhsSymbol, 'dtype') and self._isDeclaration:
-            if self._isConst:
-                dtype = "const " + self._lhsSymbol.dtype + " "
-            else:
-                dtype = self._lhsSymbol.dtype + " "
-
-        return c.Assign(dtype + codePrinter.doprint(self._lhsSymbol),
-                        codePrinter.doprint(self.rhs))
-
-
-class CustomCppCode(Node):
-    def __init__(self, code, symbolsRead, symbolsDefined):
-        self._code = "\n" + code
-        self._symbolsRead = set(symbolsRead)
-        self._symbolsDefined = set(symbolsDefined)
-
-    @property
-    def args(self):
-        return []
-
-    @property
-    def symbolsDefined(self):
-        return self._symbolsDefined
-
-    @property
-    def symbolsRead(self):
-        return self._symbolsRead
-
-    def generateC(self):
-        return c.LiteralLines(self._code)
-
-
-class TemporaryArrayDefinition(Node):
-    def __init__(self, typedSymbol, size):
-        self._symbol = typedSymbol
-        self._size = size
-
-    @property
-    def symbolsDefined(self):
-        return set([self._symbol])
-
-    @property
-    def symbolsRead(self):
-        return set()
-
-    def generateC(self):
-        return c.Assign(self._symbol.dtype + " * " + codePrinter.doprint(self._symbol),
-                        "new %s[%s]" % (self._symbol.dtype, codePrinter.doprint(self._size)))
-
-    @property
-    def args(self):
-        return [self._symbol]
-
-
-class TemporaryArrayDelete(Node):
-    def __init__(self, typedSymbol):
-        self._symbol = typedSymbol
-
-    @property
-    def symbolsDefined(self):
-        return set()
-
-    @property
-    def symbolsRead(self):
-        return set()
-
-    def generateC(self):
-        return c.Statement("delete [] %s" % (codePrinter.doprint(self._symbol),))
-
-    @property
-    def args(self):
-        return []
-
-
-# --------------------------------------- Factory Functions ------------------------------------------------------------
-
-
-def getOptimalLoopOrdering(fields):
-    assert len(fields) > 0
-    refField = next(iter(fields))
-    for field in fields:
-        if field.spatialDimensions != refField.spatialDimensions:
-            raise ValueError("All fields have to have the same number of spatial dimensions")
-
-    layouts = set([field.layout for field in fields])
-    if len(layouts) > 1:
-        raise ValueError("Due to different layout of the fields no optimal loop ordering exists")
-    layout = list(layouts)[0]
-    return list(reversed(layout))
-
-
-def makeLoopOverDomain(body, functionName):
-    """
-    :param body: list of nodes
-    :param functionName: name of generated C function
-    :return: LoopOverCoordinate instance with nested loops, ordered according to field layouts
-    """
-    # find correct ordering by inspecting participating FieldAccesses
-    fieldAccesses = body.atoms(Field.Access)
-    fieldList = [e.field for e in fieldAccesses]
-    fields = set(fieldList)
-    loopOrder = getOptimalLoopOrdering(fields)
-
-    # find number of required ghost layers
-    requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
-
-    shapes = set([f.spatialShape for f in fields])
-
-    if len(shapes) > 1:
-        nrOfFixedSizedFields = 0
-        for shape in shapes:
-            if not isinstance(shape[0], sp.Basic):
-                nrOfFixedSizedFields += 1
-        assert nrOfFixedSizedFields <= 1, "Differently sized field accesses in loop body: " + str(shapes)
-    shape = list(shapes)[0]
-
-    currentBody = body
-    lastLoop = None
-    for i, loopCoordinate in enumerate(loopOrder):
-        newLoop = LoopOverCoordinate(currentBody, loopCoordinate, shape, 1, requiredGhostLayers,
-                                     isInnermostLoop=(i == 0), isOutermostLoop=(i == len(loopOrder) - 1))
-        lastLoop = newLoop
-        currentBody = Block([lastLoop])
-    return KernelFunction(currentBody, functionName)
-
-
-# --------------------------------------- Transformations --------------------------------------------------------------
-
-def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
-    field = fieldAccess.field
-
-    offset = 0
-    name = ""
-    listToHash = []
-    for coordinateId, coordinateValue in coordinates.items():
-        offset += field.strides[coordinateId] * coordinateValue
-
-        if coordinateId < field.spatialDimensions:
-            offset += field.strides[coordinateId] * fieldAccess.offsets[coordinateId]
-            if type(fieldAccess.offsets[coordinateId]) is int:
-                offsetComp = offsetComponentToDirectionString(coordinateId, fieldAccess.offsets[coordinateId])
-                name += "_"
-                name += offsetComp if offsetComp else "C"
-            else:
-                listToHash.append(fieldAccess.offsets[coordinateId])
-        else:
-            if type(coordinateValue) is int:
-                name += "_%d" % (coordinateValue,)
-            else:
-                listToHash.append(coordinateValue)
-
-    if len(listToHash) > 0:
-        name += "%0.6X" % (abs(hash(tuple(listToHash))))
-
-    newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
-    return newPtr, offset
-
-
-def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
-    """
-    Allowed specifications:
-    "spatialInner<int>" spatialInner0 is the innermost loop coordinate, spatialInner1 the loop enclosing the innermost
-    "spatialOuter<int>" spatialOuter0 is the outermost loop
-    "index<int>": index coordinate
-    "<int>": specifying directly the coordinate
-    :param basePointerSpecification: nested list with above specifications
-    :param loopOrder: list with ordering of loops from inner to outer
-    :param field:
-    :return:
-    """
-    result = []
-    specifiedCoordinates = set()
-    for specGroup in basePointerSpecification:
-        newGroup = []
-
-        def addNewElement(i):
-            if i >= field.spatialDimensions + field.indexDimensions:
-                raise ValueError("Coordinate %d does not exist" % (i,))
-            newGroup.append(i)
-            if i in specifiedCoordinates:
-                raise ValueError("Coordinate %d specified two times" % (i,))
-            specifiedCoordinates.add(i)
-
-        for element in specGroup:
-            if type(element) is int:
-                addNewElement(element)
-            elif element.startswith("spatial"):
-                element = element[len("spatial"):]
-                if element.startswith("Inner"):
-                    index = int(element[len("Inner"):])
-                    addNewElement(loopOrder[index])
-                elif element.startswith("Outer"):
-                    index = int(element[len("Outer"):])
-                    addNewElement(loopOrder[-index])
-                elif element == "all":
-                    for i in range(field.spatialDimensions):
-                        addNewElement(i)
-                else:
-                    raise ValueError("Could not parse " + element)
-            elif element.startswith("index"):
-                index = int(element[len("index"):])
-                addNewElement(field.spatialDimensions + index)
-            else:
-                raise ValueError("Unknown specification %s" % (element,))
-
-        result.append(newGroup)
-
-    allCoordinates = set(range(field.spatialDimensions + field.indexDimensions))
-    rest = allCoordinates - specifiedCoordinates
-    if rest:
-        result.append(list(rest))
-    return result
-
-
-def getLoopHierarchy(block):
-    result = []
-    node = block
-    while node is not None:
-        node = getNextParentOfType(node, LoopOverCoordinate)
-        if node:
-            result.append(node.coordinateToLoopOver)
-    return result
-
-
-def resolveFieldAccesses(ast, fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
-    """Substitutes FieldAccess nodes by array indexing"""
-
-    def visitSympyExpr(expr, enclosingBlock):
-        if isinstance(expr, Field.Access):
-            fieldAccess = expr
-            field = fieldAccess.field
-            if field.name in fieldToBasePointerInfo:
-                basePointerInfo = fieldToBasePointerInfo[field.name]
-            else:
-                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
-
-            dtype = "%s * __restrict__" % field.dtype
-            if field.readOnly:
-                dtype = "const " + dtype
-
-            fieldPtr = TypedSymbol("%s%s" % (FIELD_PTR_PREFIX, field.name), dtype)
-
-            lastPointer = fieldPtr
-
-            def createCoordinateDict(group):
-                coordDict = {}
-                for e in group:
-                    if e < field.spatialDimensions:
-                        if field.name in fieldToFixedCoordinates:
-                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
-                        else:
-                            coordDict[e] = TypedSymbol("%s_%d" % (COORDINATE_LOOP_COUNTER_NAME, e), "int")
-                    else:
-                        coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
-                return coordDict
-
-            for group in reversed(basePointerInfo[1:]):
-                coordDict = createCoordinateDict(group)
-                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
-                if newPtr not in enclosingBlock.symbolsDefined:
-                    enclosingBlock.insertFront(SympyAssignment(newPtr, lastPointer + offset, isConst=False))
-                lastPointer = newPtr
-
-            _, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]), lastPointer)
-            baseArr = IndexedBase(lastPointer, shape=(1,))
-            return baseArr[offset]
-        else:
-            newArgs = [visitSympyExpr(e, enclosingBlock) for e in expr.args]
-            kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
-            return expr.func(*newArgs, **kwargs) if newArgs else expr
-
-    def visitNode(subAst):
-        if isinstance(subAst, SympyAssignment):
-            enclosingBlock = subAst.parent
-            assert type(enclosingBlock) is Block
-            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock)
-            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock)
-        else:
-            for i, a in enumerate(subAst.args):
-                visitNode(a)
-
-    return visitNode(ast)
-
-
-def moveConstantsBeforeLoop(ast):
-
-    def findBlockToMoveTo(node):
-        """Traverses parents of node as long as the symbols are independent and returns a (parent) block
-        the assignment can be safely moved to
-        :param node: SympyAssignment inside a Block"""
-        assert isinstance(node, SympyAssignment)
-        assert isinstance(node.parent, Block)
-
-        lastBlock = node.parent
-        element = node.parent
-        while element:
-            if isinstance(element, Block):
-                lastBlock = element
-            if node.symbolsRead.intersection(element.symbolsDefined):
-                break
-            element = element.parent
-        return lastBlock
-
-    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
-        for arg in targetBlock.args:
-            if type(arg) is not SympyAssignment:
-                continue
-            if arg.lhs == assignment.lhs:
-                return arg
-        return None
-
-    for block in ast.atoms(Block):
-        children = block.takeChildNodes()
-        for child in children:
-            if not isinstance(child, SympyAssignment):
-                block.append(child)
-            else:
-                target = findBlockToMoveTo(child)
-                if target == block:     # movement not possible
-                    target.append(child)
-                else:
-                    existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
-                    if not existingAssignment:
-                        target.insertFront(child)
-                    else:
-                        assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
-
-
-def splitInnerLoop(ast, symbolGroups):
-    allLoops = ast.atoms(LoopOverCoordinate)
-    innerLoop = [l for l in allLoops if l.isInnermostLoop]
-    assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
-    innerLoop = innerLoop[0]
-    assert type(innerLoop.body) is Block
-    outerLoop = [l for l in allLoops if l.isOutermostLoop]
-    assert len(outerLoop) == 1, "Error in AST, multiple outermost loops."
-    outerLoop = outerLoop[0]
-
-    symbolsWithTemporaryArray = dict()
-
-    assignmentMap = {a.lhs: a for a in innerLoop.body.args}
-
-    assignmentGroups = []
-    for symbolGroup in symbolGroups:
-        # get all dependent symbols
-        symbolsToProcess = list(symbolGroup)
-        symbolsResolved = set()
-        while symbolsToProcess:
-            s = symbolsToProcess.pop()
-            if s in symbolsResolved:
-                continue
-
-            if s in assignmentMap:  # if there is no assignment inside the loop body it is independent already
-                for newSymbol in assignmentMap[s].rhs.atoms(sp.Symbol):
-                    if type(newSymbol) is not Field.Access and newSymbol not in symbolsWithTemporaryArray:
-                        symbolsToProcess.append(newSymbol)
-            symbolsResolved.add(s)
-
-        for symbol in symbolGroup:
-            if type(symbol) is not Field.Access:
-                assert type(symbol) is TypedSymbol
-                symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]
-
-        assignmentGroup = []
-        for assignment in innerLoop.body.args:
-            if assignment.lhs in symbolsResolved:
-                newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
-                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
-                    newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
-                else:
-                    newLhs = assignment.lhs
-                assignmentGroup.append(SympyAssignment(newLhs, newRhs))
-        assignmentGroups.append(assignmentGroup)
-
-    newLoops = [innerLoop.newLoopWithDifferentBody(Block(group)) for group in assignmentGroups]
-    innerLoop.parent.replace(innerLoop, newLoops)
-
-    for tmpArray in symbolsWithTemporaryArray:
-        outerLoop.parent.insertFront(TemporaryArrayDefinition(tmpArray, innerLoop.iterationRegionWithGhostLayer))
-        outerLoop.parent.append(TemporaryArrayDelete(tmpArray))
-
-
-# ------------------------------------- Main ---------------------------------------------------------------------------
-
-
-def extractCommonSubexpressions(equations):
-    """Uses sympy to find common subexpressions in equations and returns
-    them in a topologically sorted order, ready for evaluation"""
-    replacements, newEq = sp.cse(equations)
-    replacementEqs = [sp.Eq(*r) for r in replacements]
-    equations = replacementEqs + newEq
-    topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations])
-    equations = [sp.Eq(*a) for a in topologicallySortedPairs]
-    return equations
-
-
-def addOpenMP(ast):
-    assert type(ast) is KernelFunction
-    body = ast.body
-    wrapperBlock = PragmaBlock('#pragma omp parallel', body.takeChildNodes())
-    body.append(wrapperBlock)
-
-    outerLoops = [l for l in body.atoms(LoopOverCoordinate) if l.isOutermostLoop]
-    assert outerLoops, "No outer loop found"
-    assert len(outerLoops) <= 1, "More than one outer loop found. Which one should be parallelized?"
-    outerLoops[0].prefixLines.append("#pragma omp for schedule(static)")
-
-
-def typeAllEquations(eqs, typeForSymbol):
-    fieldsWritten = set()
-    fieldsRead = set()
-
-    def processRhs(term):
-        """Replaces Symbols by:
-            - TypedSymbol if symbol is not a field access
-        """
-        if isinstance(term, Field.Access):
-            fieldsRead.add(term.field)
-            return term
-        elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, typeForSymbol[term.name])
-        else:
-            newArgs = [processRhs(arg) for arg in term.args]
-            return term.func(*newArgs) if newArgs else term
-
-    def processLhs(term):
-        """Replaces symbol by TypedSymbol and adds field to fieldsWriten"""
-        if isinstance(term, Field.Access):
-            fieldsWritten.add(term.field)
-            return term
-        elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, typeForSymbol[term.name])
-        else:
-            assert False, "Expected a symbol as left-hand-side"
-
-    typedEquations = []
-    for eq in eqs:
-        if isinstance(eq, sp.Eq):
-            newLhs = processLhs(eq.lhs)
-            newRhs = processRhs(eq.rhs)
-            typedEquations.append(SympyAssignment(newLhs, newRhs))
-        else:
-            assert isinstance(eq, Node), "Only equations and ast nodes are allowed in input"
-            typedEquations.append(eq)
-
-    typedEquations = typedEquations
-
-    return fieldsRead, fieldsWritten, typedEquations
-
-
-def typingFromSympyInspection(eqs, defaultType="double"):
-    result = defaultdict(lambda: defaultType)
-    for eq in eqs:
-        if isinstance(eq.rhs, Boolean):
-            result[eq.lhs.name] = "bool"
-    return result
-
-
-def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=[]):
-    if not typeForSymbol:
-        typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
-
-    def typeSymbol(term):
-        if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
-            return term
-        elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, typeForSymbol[term.name])
-        else:
-            raise ValueError("Term has to be field access or symbol")
-
-    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
-    allFields = fieldsRead.union(fieldsWritten)
-
-    for field in allFields:
-        field.setReadOnly(False)
-    for field in fieldsRead - fieldsWritten:
-        field.setReadOnly()
-
-    body = Block(assignments)
-    code = makeLoopOverDomain(body, functionName)
-
-    if splitGroups:
-        typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
-        splitInnerLoop(code, typedSplitGroups)
-
-    loopOrder = getOptimalLoopOrdering(allFields)
-
-    basePointerInfo = [['spatialInner0'], ['spatialInner1']]
-    basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, loopOrder, f) for f in allFields}
-
-    resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos)
-    moveConstantsBeforeLoop(code)
-    addOpenMP(code)
-
-    return code
-
-
-if __name__ == "__main__":
-    f = Field.createGeneric('f', 3, indexDimensions=1)
-    pointerSpec = [['spatialInner0']]
-    parseBasePointerInfo(pointerSpec, [0, 1, 2], f)
\ No newline at end of file
diff --git a/jit.py b/jit.py
index fdaf9800e..5debe1a68 100644
--- a/jit.py
+++ b/jit.py
@@ -2,7 +2,7 @@ import os
 import subprocess
 from ctypes import cdll, c_double, c_float, sizeof
 from tempfile import TemporaryDirectory
-
+from pystencils.backends.cbackend import printCCode
 import numpy as np
 
 
@@ -67,7 +67,7 @@ def compileAndLoad(kernelFunctionNode):
             print('#include <iostream>', file=sourceFile)
             print("#include <cmath>", file=sourceFile)
             print('extern "C" { ', file=sourceFile)
-            print(kernelFunctionNode.generateC(), file=sourceFile)
+            print(printCCode(kernelFunctionNode), file=sourceFile)
             print('}', file=sourceFile)
 
         compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
diff --git a/transformations.py b/transformations.py
new file mode 100644
index 000000000..0ef3b0a3f
--- /dev/null
+++ b/transformations.py
@@ -0,0 +1,444 @@
+from collections import defaultdict
+import sympy as sp
+from sympy.logic.boolalg import Boolean
+from sympy.tensor import IndexedBase
+from pystencils.field import Field, offsetComponentToDirectionString
+from pystencils.typedsymbol import TypedSymbol
+import pystencils.ast as ast
+
+
+# --------------------------------------- Factory Functions ------------------------------------------------------------
+
+
+def makeLoopOverDomain(body, functionName):
+    """
+    :param body: list of nodes
+    :param functionName: name of generated C function
+    :return: LoopOverCoordinate instance with nested loops, ordered according to field layouts
+    """
+    # find correct ordering by inspecting participating FieldAccesses
+    fieldAccesses = body.atoms(Field.Access)
+    fieldList = [e.field for e in fieldAccesses]
+    fields = set(fieldList)
+    loopOrder = getOptimalLoopOrdering(fields)
+
+    # find number of required ghost layers
+    requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
+
+    shapes = set([f.spatialShape for f in fields])
+
+    if len(shapes) > 1:
+        nrOfFixedSizedFields = 0
+        for shape in shapes:
+            if not isinstance(shape[0], sp.Basic):
+                nrOfFixedSizedFields += 1
+        assert nrOfFixedSizedFields <= 1, "Differently sized field accesses in loop body: " + str(shapes)
+    shape = list(shapes)[0]
+
+    currentBody = body
+    lastLoop = None
+    for i, loopCoordinate in enumerate(loopOrder):
+        newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, shape, 1, requiredGhostLayers,
+                                         isInnermostLoop=(i == 0), isOutermostLoop=(i == len(loopOrder) - 1))
+        lastLoop = newLoop
+        currentBody = ast.Block([lastLoop])
+    return ast.KernelFunction(currentBody, functionName)
+
+
+# --------------------------------------- Transformations --------------------------------------------------------------
+
+def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
+    field = fieldAccess.field
+
+    offset = 0
+    name = ""
+    listToHash = []
+    for coordinateId, coordinateValue in coordinates.items():
+        offset += field.strides[coordinateId] * coordinateValue
+
+        if coordinateId < field.spatialDimensions:
+            offset += field.strides[coordinateId] * fieldAccess.offsets[coordinateId]
+            if type(fieldAccess.offsets[coordinateId]) is int:
+                offsetComp = offsetComponentToDirectionString(coordinateId, fieldAccess.offsets[coordinateId])
+                name += "_"
+                name += offsetComp if offsetComp else "C"
+            else:
+                listToHash.append(fieldAccess.offsets[coordinateId])
+        else:
+            if type(coordinateValue) is int:
+                name += "_%d" % (coordinateValue,)
+            else:
+                listToHash.append(coordinateValue)
+
+    if len(listToHash) > 0:
+        name += "%0.6X" % (abs(hash(tuple(listToHash))))
+
+    newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
+    return newPtr, offset
+
+
+def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
+    """
+    Allowed specifications:
+    "spatialInner<int>" spatialInner0 is the innermost loop coordinate, spatialInner1 the loop enclosing the innermost
+    "spatialOuter<int>" spatialOuter0 is the outermost loop
+    "index<int>": index coordinate
+    "<int>": specifying directly the coordinate
+    :param basePointerSpecification: nested list with above specifications
+    :param loopOrder: list with ordering of loops from inner to outer
+    :param field:
+    :return:
+    """
+    result = []
+    specifiedCoordinates = set()
+    for specGroup in basePointerSpecification:
+        newGroup = []
+
+        def addNewElement(i):
+            if i >= field.spatialDimensions + field.indexDimensions:
+                raise ValueError("Coordinate %d does not exist" % (i,))
+            newGroup.append(i)
+            if i in specifiedCoordinates:
+                raise ValueError("Coordinate %d specified two times" % (i,))
+            specifiedCoordinates.add(i)
+
+        for element in specGroup:
+            if type(element) is int:
+                addNewElement(element)
+            elif element.startswith("spatial"):
+                element = element[len("spatial"):]
+                if element.startswith("Inner"):
+                    index = int(element[len("Inner"):])
+                    addNewElement(loopOrder[index])
+                elif element.startswith("Outer"):
+                    index = int(element[len("Outer"):])
+                    addNewElement(loopOrder[-index])
+                elif element == "all":
+                    for i in range(field.spatialDimensions):
+                        addNewElement(i)
+                else:
+                    raise ValueError("Could not parse " + element)
+            elif element.startswith("index"):
+                index = int(element[len("index"):])
+                addNewElement(field.spatialDimensions + index)
+            else:
+                raise ValueError("Unknown specification %s" % (element,))
+
+        result.append(newGroup)
+
+    allCoordinates = set(range(field.spatialDimensions + field.indexDimensions))
+    rest = allCoordinates - specifiedCoordinates
+    if rest:
+        result.append(list(rest))
+    return result
+
+
+def resolveFieldAccesses(astNode, fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
+    """Substitutes FieldAccess nodes by array indexing"""
+
+    def visitSympyExpr(expr, enclosingBlock):
+        if isinstance(expr, Field.Access):
+            fieldAccess = expr
+            field = fieldAccess.field
+            if field.name in fieldToBasePointerInfo:
+                basePointerInfo = fieldToBasePointerInfo[field.name]
+            else:
+                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
+
+            dtype = "%s * __restrict__" % field.dtype
+            if field.readOnly:
+                dtype = "const " + dtype
+
+            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype)
+
+            lastPointer = fieldPtr
+
+            def createCoordinateDict(group):
+                coordDict = {}
+                for e in group:
+                    if e < field.spatialDimensions:
+                        if field.name in fieldToFixedCoordinates:
+                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
+                        else:
+                            ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
+                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), "int")
+                    else:
+                        coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
+                return coordDict
+
+            for group in reversed(basePointerInfo[1:]):
+                coordDict = createCoordinateDict(group)
+                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
+                if newPtr not in enclosingBlock.symbolsDefined:
+                    enclosingBlock.insertFront(ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False))
+                lastPointer = newPtr
+
+            _, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]),
+                                                      lastPointer)
+            baseArr = IndexedBase(lastPointer, shape=(1,))
+            return baseArr[offset]
+        else:
+            newArgs = [visitSympyExpr(e, enclosingBlock) for e in expr.args]
+            kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
+            return expr.func(*newArgs, **kwargs) if newArgs else expr
+
+    def visitNode(subAst):
+        if isinstance(subAst, ast.SympyAssignment):
+            enclosingBlock = subAst.parent
+            assert type(enclosingBlock) is ast.Block
+            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock)
+            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock)
+        else:
+            for i, a in enumerate(subAst.args):
+                visitNode(a)
+
+    return visitNode(astNode)
+
+
+def moveConstantsBeforeLoop(astNode):
+
+    def findBlockToMoveTo(node):
+        """Traverses parents of node as long as the symbols are independent and returns a (parent) block
+        the assignment can be safely moved to
+        :param node: SympyAssignment inside a Block"""
+        assert isinstance(node, ast.SympyAssignment)
+        assert isinstance(node.parent, ast.Block)
+
+        lastBlock = node.parent
+        element = node.parent
+        while element:
+            if isinstance(element, ast.Block):
+                lastBlock = element
+            if node.symbolsRead.intersection(element.symbolsDefined):
+                break
+            element = element.parent
+        return lastBlock
+
+    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
+        for arg in targetBlock.args:
+            if type(arg) is not ast.SympyAssignment:
+                continue
+            if arg.lhs == assignment.lhs:
+                return arg
+        return None
+
+    for block in astNode.atoms(ast.Block):
+        children = block.takeChildNodes()
+        for child in children:
+            if not isinstance(child, ast.SympyAssignment):
+                block.append(child)
+            else:
+                target = findBlockToMoveTo(child)
+                if target == block:     # movement not possible
+                    target.append(child)
+                else:
+                    existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
+                    if not existingAssignment:
+                        target.insertFront(child)
+                    else:
+                        assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
+
+
+def splitInnerLoop(astNode, symbolGroups):
+    allLoops = astNode.atoms(ast.LoopOverCoordinate)
+    innerLoop = [l for l in allLoops if l.isInnermostLoop]
+    assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
+    innerLoop = innerLoop[0]
+    assert type(innerLoop.body) is ast.Block
+    outerLoop = [l for l in allLoops if l.isOutermostLoop]
+    assert len(outerLoop) == 1, "Error in AST, multiple outermost loops."
+    outerLoop = outerLoop[0]
+
+    symbolsWithTemporaryArray = dict()
+
+    assignmentMap = {a.lhs: a for a in innerLoop.body.args}
+
+    assignmentGroups = []
+    for symbolGroup in symbolGroups:
+        # get all dependent symbols
+        symbolsToProcess = list(symbolGroup)
+        symbolsResolved = set()
+        while symbolsToProcess:
+            s = symbolsToProcess.pop()
+            if s in symbolsResolved:
+                continue
+
+            if s in assignmentMap:  # if there is no assignment inside the loop body it is independent already
+                for newSymbol in assignmentMap[s].rhs.atoms(sp.Symbol):
+                    if type(newSymbol) is not Field.Access and newSymbol not in symbolsWithTemporaryArray:
+                        symbolsToProcess.append(newSymbol)
+            symbolsResolved.add(s)
+
+        for symbol in symbolGroup:
+            if type(symbol) is not Field.Access:
+                assert type(symbol) is TypedSymbol
+                symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]
+
+        assignmentGroup = []
+        for assignment in innerLoop.body.args:
+            if assignment.lhs in symbolsResolved:
+                newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
+                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
+                    newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
+                else:
+                    newLhs = assignment.lhs
+                assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
+        assignmentGroups.append(assignmentGroup)
+
+    newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
+    innerLoop.parent.replace(innerLoop, newLoops)
+
+    for tmpArray in symbolsWithTemporaryArray:
+        outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.iterationRegionWithGhostLayer))
+        outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))
+
+
+# ------------------------------------- Main ---------------------------------------------------------------------------
+
+
+def extractCommonSubexpressions(equations):
+    """Uses sympy to find common subexpressions in equations and returns
+    them in a topologically sorted order, ready for evaluation"""
+    replacements, newEq = sp.cse(equations)
+    replacementEqs = [sp.Eq(*r) for r in replacements]
+    equations = replacementEqs + newEq
+    topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations])
+    equations = [sp.Eq(*a) for a in topologicallySortedPairs]
+    return equations
+
+
+def addOpenMP(astNode):
+    assert type(astNode) is ast.KernelFunction
+    body = astNode.body
+    wrapperBlock = ast.PragmaBlock('#pragma omp parallel', body.takeChildNodes())
+    body.append(wrapperBlock)
+
+    outerLoops = [l for l in body.atoms(ast.LoopOverCoordinate) if l.isOutermostLoop]
+    assert outerLoops, "No outer loop found"
+    assert len(outerLoops) <= 1, "More than one outer loop found. Which one should be parallelized?"
+    outerLoops[0].prefixLines.append("#pragma omp for schedule(static)")
+
+
+def typeAllEquations(eqs, typeForSymbol):
+    fieldsWritten = set()
+    fieldsRead = set()
+
+    def processRhs(term):
+        """Replaces Symbols by:
+            - TypedSymbol if symbol is not a field access
+        """
+        if isinstance(term, Field.Access):
+            fieldsRead.add(term.field)
+            return term
+        elif isinstance(term, sp.Symbol):
+            return TypedSymbol(term.name, typeForSymbol[term.name])
+        else:
+            newArgs = [processRhs(arg) for arg in term.args]
+            return term.func(*newArgs) if newArgs else term
+
+    def processLhs(term):
+        """Replaces symbol by TypedSymbol and adds field to fieldsWriten"""
+        if isinstance(term, Field.Access):
+            fieldsWritten.add(term.field)
+            return term
+        elif isinstance(term, sp.Symbol):
+            return TypedSymbol(term.name, typeForSymbol[term.name])
+        else:
+            assert False, "Expected a symbol as left-hand-side"
+
+    typedEquations = []
+    for eq in eqs:
+        if isinstance(eq, sp.Eq):
+            newLhs = processLhs(eq.lhs)
+            newRhs = processRhs(eq.rhs)
+            typedEquations.append(ast.SympyAssignment(newLhs, newRhs))
+        else:
+            assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input"
+            typedEquations.append(eq)
+
+    typedEquations = typedEquations
+
+    return fieldsRead, fieldsWritten, typedEquations
+
+
+def typingFromSympyInspection(eqs, defaultType="double"):
+    result = defaultdict(lambda: defaultType)
+    for eq in eqs:
+        if isinstance(eq.rhs, Boolean):
+            result[eq.lhs.name] = "bool"
+    return result
+
+
+def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=[]):
+    if not typeForSymbol:
+        typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
+
+    def typeSymbol(term):
+        if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
+            return term
+        elif isinstance(term, sp.Symbol):
+            return TypedSymbol(term.name, typeForSymbol[term.name])
+        else:
+            raise ValueError("Term has to be field access or symbol")
+
+    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
+    allFields = fieldsRead.union(fieldsWritten)
+
+    for field in allFields:
+        field.setReadOnly(False)
+    for field in fieldsRead - fieldsWritten:
+        field.setReadOnly()
+
+    body = ast.Block(assignments)
+    code = makeLoopOverDomain(body, functionName)
+
+    if splitGroups:
+        typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
+        splitInnerLoop(code, typedSplitGroups)
+
+    loopOrder = getOptimalLoopOrdering(allFields)
+
+    basePointerInfo = [['spatialInner0'], ['spatialInner1']]
+    basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
+
+    resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos)
+    moveConstantsBeforeLoop(code)
+    addOpenMP(code)
+
+    return code
+
+
+# --------------------------------------- Helper Functions -------------------------------------------------------------
+
+
+def getNextParentOfType(node, parentType):
+    parent = node.parent
+    while parent is not None:
+        if isinstance(parent, parentType):
+            return parent
+        parent = parent.parent
+    return None
+
+
+def getOptimalLoopOrdering(fields):
+    assert len(fields) > 0
+    refField = next(iter(fields))
+    for field in fields:
+        if field.spatialDimensions != refField.spatialDimensions:
+            raise ValueError("All fields have to have the same number of spatial dimensions")
+
+    layouts = set([field.layout for field in fields])
+    if len(layouts) > 1:
+        raise ValueError("Due to different layout of the fields no optimal loop ordering exists")
+    layout = list(layouts)[0]
+    return list(reversed(layout))
+
+
+def getLoopHierarchy(block):
+    result = []
+    node = block
+    while node is not None:
+        node = getNextParentOfType(node, ast.LoopOverCoordinate)
+        if node:
+            result.append(node.coordinateToLoopOver)
+    return result
\ No newline at end of file
-- 
GitLab