From 6154104699bde5a50e50dfd74c112cee00f56f8b Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 2 Nov 2016 12:07:25 +0100
Subject: [PATCH] Restructuring: moved to pystencils

---
 cuda.py              |  73 ++++
 cudajit.py           |  65 ++++
 field.py             | 385 +++++++++++++++++++
 finitedifferences.py |   5 +-
 generator.py         | 893 +++++++++++++++++++++++++++++++++++++++++++
 jit.py               | 149 ++++++++
 typedsymbol.py       |  26 ++
 7 files changed, 1594 insertions(+), 2 deletions(-)
 create mode 100644 cuda.py
 create mode 100644 cudajit.py
 create mode 100644 field.py
 create mode 100644 generator.py
 create mode 100644 jit.py
 create mode 100644 typedsymbol.py

diff --git a/cuda.py b/cuda.py
new file mode 100644
index 000000000..a9af8361a
--- /dev/null
+++ b/cuda.py
@@ -0,0 +1,73 @@
+from collections import defaultdict
+
+import sympy as sp
+
+from pystencils.generator import resolveFieldAccesses
+from pystencils.generator import typeAllEquations, Block, KernelFunction, parseBasePointerInfo
+
+BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
+THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z"))
+
+"""
+GPU Access Patterns
+
+- knows about the iteration range
+- know about mapping of field indices to CUDA block and thread indices
+- iterates over spatial coordinates - constructed with a specific number of coordinates
+- can
+"""
+
+
+def getLinewiseCoordinateAccessExpression(field, indexCoordinate):
+    availableIndices = [THREAD_IDX[0]] + BLOCK_IDX
+    fastestCoordinate = field.layout[-1]
+    availableIndices[fastestCoordinate], availableIndices[0] = availableIndices[0], availableIndices[fastestCoordinate]
+    cudaIndices = availableIndices[:field.spatialDimensions]
+
+    offsetToCell = sum([cudaIdx * stride for cudaIdx, stride in zip(cudaIndices, field.spatialStrides)])
+    indexOffset = sum([idx * indexStride for idx, indexStride in zip(indexCoordinate, field.indexStrides)])
+    return sp.simplify(offsetToCell + indexOffset)
+
+
+def getLinewiseCoordinates(field):
+    availableIndices = [THREAD_IDX[0]] + BLOCK_IDX
+    d = field.spatialDimensions + field.indexDimensions
+    fastestCoordinate = field.layout[-1]
+    result = availableIndices[:d]
+    result[0], result[fastestCoordinate] = result[fastestCoordinate], result[0]
+    return result
+
+
+def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defaultdict(lambda: "double")):
+    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
+    for f in fieldsRead - fieldsWritten:
+        f.setReadOnly()
+
+    code = KernelFunction(Block(assignments), functionName)
+    code.qualifierPrefix = "__global__ "
+    code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX)
+
+    coordMapping = getLinewiseCoordinates(list(fieldsRead)[0])
+    allFields = fieldsRead.union(fieldsWritten)
+    basePointerInfo = [['spatialInner0']]
+    basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [0, 1, 2], f) for f in allFields}
+    resolveFieldAccesses(code, fieldToFixedCoordinates={'src': coordMapping, 'dst': coordMapping},
+                         fieldToBasePointerInfo=basePointerInfos)
+    return code
+
+
+if __name__ == "__main__":
+    import sympy as sp
+    from lbmpy.stencils import getStencil
+    from lbmpy.collisionoperator import makeSRT
+    from lbmpy.lbmgenerator import createLbmEquations
+
+    latticeModel = makeSRT(getStencil("D2Q9"), order=2, compressible=False)
+    r = createLbmEquations(latticeModel, doCSE=True)
+    kernel = createCUDAKernel(r)
+    print(kernel.generateC())
+
+    from pycuda.compiler import SourceModule
+
+    mod = SourceModule(str(kernel.generateC()))
+    func = mod.get_function("kernel")
diff --git a/cudajit.py b/cudajit.py
new file mode 100644
index 000000000..29239b63e
--- /dev/null
+++ b/cudajit.py
@@ -0,0 +1,65 @@
+import numpy as np
+import pycuda.driver as cuda
+from pycuda.compiler import SourceModule
+
+
+def numpyTypeFromString(typename, includePointers=True):
+    import ctypes as ct
+
+    typename = typename.replace("*", " * ")
+    typeComponents = typename.split()
+
+    basicTypeMap = {
+        'double': np.float64,
+        'float': np.float32,
+        'int': np.int32,
+        'long': np.int64,
+    }
+
+    resultType = None
+    for typeComponent in typeComponents:
+        typeComponent = typeComponent.strip()
+        if typeComponent == "const" or typeComponent == "restrict" or typeComponent == "volatile":
+            continue
+        if typeComponent in basicTypeMap:
+            resultType = basicTypeMap[typeComponent]
+        elif typeComponent == "*" and includePointers:
+            assert resultType is not None
+            resultType = ct.POINTER(resultType)
+
+    return resultType
+
+
+def buildNumpyArgumentList(kernelFunctionNode, argumentDict):
+    result = []
+    for arg in kernelFunctionNode.parameters:
+        if arg.isFieldArgument:
+            field = argumentDict[arg.fieldName]
+            if arg.isFieldPtrArgument:
+                result.append(field.gpudata)
+            elif arg.isFieldShapeArgument:
+                strideArr = np.array(field.strides, dtype=np.int32) / field.dtype.itemsize
+                result.append(cuda.In(strideArr))
+            elif arg.isFieldStrideArgument:
+                shapeArr = np.array(field.shape, dtype=np.int32)
+                result.append(cuda.In(shapeArr))
+            else:
+                assert False
+        else:
+            param = argumentDict[arg.name]
+            expectedType = numpyTypeFromString(arg.dtype)
+            result.append(expectedType(param))
+    return result
+
+
+def makePythonFunction(kernelFunctionNode, argumentDict):
+    mod = SourceModule(str(kernelFunctionNode.generateC()))
+    func = mod.get_function(kernelFunctionNode.functionName)
+
+    # 1) get argument list
+    args = buildNumpyArgumentList(kernelFunctionNode, argumentDict)
+
+    # 2) determine block and grid tuples
+        
+    # TODO prepare the function here
+
diff --git a/field.py b/field.py
new file mode 100644
index 000000000..5067bc9de
--- /dev/null
+++ b/field.py
@@ -0,0 +1,385 @@
+from itertools import chain
+import numpy as np
+import sympy as sp
+from sympy.core.cache import cacheit
+from sympy.tensor import IndexedBase
+from pystencils.typedsymbol import TypedSymbol
+
+
+def getLayoutFromNumpyArray(arr):
+    """
+    Returns a list indicating the memory layout (linearization order) of the numpy array.
+    Example:
+        >>> getLayoutFromNumpyArray(np.zeros([3,3,3]))
+        [0, 1, 2]
+    In this example the loop over the zeroth coordinate should be the outermost loop,
+    followed by the first and second. Elements arr[x,y,0] and arr[x,y,1] are adjacent in memory.
+    Normally constructed numpy arrays have this order, however by stride tricks or other frameworks, arrays
+    with different memory layout can be created.
+    """
+    coordinates = list(range(len(arr.shape)))
+    return [x for (y, x) in sorted(zip(arr.strides, coordinates), key=lambda pair: pair[0], reverse=True)]
+
+
+def numpyDataTypeToC(dtype):
+    """Mapping numpy data types to C data types"""
+    if dtype == np.float64:
+        return "double"
+    elif dtype == np.float32:
+        return "float"
+    elif dtype == np.int32:
+        return "int"
+    raise NotImplementedError()
+
+
+def offsetComponentToDirectionString(coordinateId, value):
+    """
+    Translates numerical offset to string notation.
+    x offsets are labeled with east 'E' and 'W',
+    y offsets with north 'N' and 'S' and
+    z offsets with top 'T' and bottom 'B'
+    If the absolute value of the offset is bigger than 1, this number is prefixed.
+    :param coordinateId: integer 0, 1 or 2 standing for x,y and z
+    :param value: integer offset
+
+    Example:
+    >>> offsetComponentToDirectionString(0, 1)
+    'E'
+    >>> offsetComponentToDirectionString(1, 2)
+    '2N'
+    """
+    nameComponents = (('W', 'E'),  # west, east
+                      ('S', 'N'),  # south, north
+                      ('B', 'T'),  # bottom, top
+                      )
+    if value == 0:
+        result = ""
+    elif value < 0:
+        result = nameComponents[coordinateId][0]
+    else:
+        result = nameComponents[coordinateId][1]
+    if abs(value) > 1:
+        result = "%d%s" % (abs(value), result)
+    return result
+
+
+def offsetToDirectionString(offsetTuple):
+    """
+    Translates numerical offset to string notation.
+    For details see :func:`offsetComponentToDirectionString`
+    :param offsetTuple: 3-tuple with x,y,z offset
+
+    Example:
+    >>> offsetToDirectionString([1, -1, 0])
+    'SE'
+    >>> offsetToDirectionString(([-3, 0, -2]))
+    '2B3W'
+    """
+    names = ["", "", ""]
+    for i in range(len(offsetTuple)):
+        names[i] = offsetComponentToDirectionString(i, offsetTuple[i])
+    name = "".join(reversed(names))
+    if name == "":
+        name = "C"
+    return name
+
+
+def directionStringToOffset(directionStr, dim=3):
+    """
+    Reverse mapping of :func:`offsetToDirectionString`
+    :param directionStr: string representation of offset
+    :param dim: dimension of offset, i.e the length of the returned list
+
+    >>> directionStringToOffset('NW', dim=3)
+    array([-1,  1,  0])
+    >>> directionStringToOffset('NW', dim=2)
+    array([-1,  1])
+    >>> directionStringToOffset(offsetToDirectionString([3,-2,1]))
+    array([ 3, -2,  1])
+    """
+    offsetMap = {
+        'C': np.array([0, 0, 0]),
+
+        'W': np.array([-1, 0, 0]),
+        'E': np.array([1, 0, 0]),
+
+        'S': np.array([0, -1, 0]),
+        'N': np.array([0, 1, 0]),
+
+        'B': np.array([0, 0, -1]),
+        'T': np.array([0, 0, 1]),
+    }
+    offset = np.array([0, 0, 0])
+
+    while len(directionStr) > 0:
+        factor = 1
+        firstNonDigit = 0
+        while directionStr[firstNonDigit].isdigit():
+            firstNonDigit += 1
+        if firstNonDigit > 0:
+            factor = int(directionStr[:firstNonDigit])
+            directionStr = directionStr[firstNonDigit:]
+        curOffset = offsetMap[directionStr[0]]
+        offset += factor * curOffset
+        directionStr = directionStr[1:]
+    return offset[:dim]
+
+
+class Field:
+    """
+    With fields one can formulate stencil-like update rules on structured grids.
+    This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
+
+    To create a field use one of the static create* members. There are two options:
+        1. create a kernel with fixed loop sizes i.e. the shape of the array is already known. This is usually the
+           case if just-in-time compilation directly from Python is done. (see Field.createFromNumpyArray)
+        2. create a more general kernel that works for variable array sizes. This can be used to create kernels
+           beforehand for a library. (see Field.createGeneric)
+
+    Dimensions:
+        A field has spatial and index dimensions, where the spatial dimensions come first.
+        The interpretation is that the field has multiple cells in (usually) two or three dimensional space which are
+        looped over. Additionally  N values are stored per cell. In this case spatialDimensions is two or three,
+        and indexDimensions equals N. If you want to store a matrix on each point in a two dimensional grid, there
+        are four dimensions, two spatial and two index dimensions. len(arr.shape) == spatialDims + indexDims
+
+    Indexing:
+        When accessing (indexing) a field the result is a FieldAccess which is derived from sympy Symbol.
+        First specify the spatial offsets in [], then in case indexDimension>0 the indices in ()
+        e.g. f[-1,0,0](7)
+
+    Example without index dimensions:
+        >>> a = np.zeros([10, 10])
+        >>> f = Field.createFromNumpyArray("f", a, indexDimensions=0)
+        >>> jacobi = ( f[-1,0] + f[1,0] + f[0,-1] + f[0,1] ) / 4
+
+    Example with index dimensions: LBM D2Q9 stream pull
+        >>> stencil = np.array([[0,0], [0,1], [0,-1]])
+        >>> src = Field.createGeneric("src", spatialDimensions=2, indexDimensions=1)
+        >>> dst = Field.createGeneric("dst", spatialDimensions=2, indexDimensions=1)
+        >>> for i, offset in enumerate(stencil):
+        ...     sp.Eq(dst[0,0](i), src[-offset](i))
+        Eq(dst_C^0, src_C^0)
+        Eq(dst_C^1, src_S^1)
+        Eq(dst_C^2, src_N^2)
+    """
+    @staticmethod
+    def createFromNumpyArray(fieldName, npArray, indexDimensions=0):
+        """
+        Creates a field based on the layout, data type, and shape of a given numpy array.
+        Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type.
+        :param fieldName: symbolic name for the field
+        :param npArray: numpy array
+        :param indexDimensions: see documentation of Field
+        """
+        spatialDimensions = len(npArray.shape) - indexDimensions
+        if spatialDimensions < 1:
+            raise ValueError("Too many index dimensions. At least one spatial dimension required")
+
+        fullLayout = getLayoutFromNumpyArray(npArray)
+        spatialLayout = tuple([i for i in fullLayout if i < spatialDimensions])
+        assert len(spatialLayout) == spatialDimensions
+
+        strides = tuple([s // np.dtype(npArray.dtype).itemsize for s in npArray.strides])
+        shape = tuple([int(s) for s in npArray.shape])
+
+        return Field(fieldName, npArray.dtype, spatialLayout, shape, strides)
+
+    @staticmethod
+    def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout=None):
+        """
+        Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
+        :param fieldName: symbolic name for the field
+        :param dtype: numpy data type of the array the kernel is called with later
+        :param spatialDimensions: see documentation of Field
+        :param indexDimensions: see documentation of Field
+        :param layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
+                       the outer loop loops over dimension 2, the second outer over dimension 1, and the inner loop
+                       over dimension 0
+        """
+        if not layout:
+            layout = tuple(reversed(range(spatialDimensions)))
+        if len(layout) != spatialDimensions:
+            raise ValueError("Layout")
+        shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,))
+        strideSymbol = IndexedBase(TypedSymbol(Field.STRIDE_PREFIX + fieldName, Field.STRIDE_DTYPE), shape=(1,))
+        totalDimensions = spatialDimensions + indexDimensions
+        shape = tuple([shapeSymbol[i] for i in range(totalDimensions)])
+        strides = tuple([strideSymbol[i] for i in range(totalDimensions)])
+        return Field(fieldName, dtype, layout, shape, strides)
+
+    def __init__(self, fieldName, dtype, layout, shape, strides):
+        """Do not use directly. Use static create* methods"""
+        self._fieldName = fieldName
+        self._dtype = numpyDataTypeToC(dtype)
+        self._layout = layout
+        self._shape = shape
+        self._strides = strides
+        self._readonly = False
+
+    @property
+    def spatialDimensions(self):
+        return len(self._layout)
+
+    @property
+    def indexDimensions(self):
+        return len(self._shape) - len(self._layout)
+
+    @property
+    def layout(self):
+        return self._layout
+
+    @property
+    def name(self):
+        return self._fieldName
+
+    @property
+    def shape(self):
+        return self._shape
+
+    @property
+    def spatialShape(self):
+        return self._shape[:self.spatialDimensions]
+
+    @property
+    def indexShape(self):
+        return self._shape[self.spatialDimensions:]
+
+    @property
+    def spatialStrides(self):
+        return self._strides[:self.spatialDimensions]
+
+    @property
+    def indexStrides(self):
+        return self._strides[self.spatialDimensions:]
+
+    @property
+    def strides(self):
+        return self._strides
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    @property
+    def readOnly(self):
+        return self._readonly
+
+    def setReadOnly(self, value=True):
+        self._readonly = value
+
+    def __repr__(self):
+        return self._fieldName
+
+    def __getitem__(self, offset):
+        if type(offset) is np.ndarray:
+            offset = tuple(offset)
+        if type(offset) is str:
+            offset = tuple(directionStringToOffset(offset, self.spatialDimensions))
+        if type(offset) is not tuple:
+            offset = (offset,)
+        if len(offset) != self.spatialDimensions:
+            raise ValueError("Wrong number of spatial indices: "
+                             "Got %d, expected %d" % (len(offset), self.spatialDimensions))
+        return Field.Access(self, offset)
+
+    def __call__(self, *args, **kwargs):
+        center = tuple([0]*self.spatialDimensions)
+        return Field.Access(self, center)(*args, **kwargs)
+
+    def __hash__(self):
+        return hash((self._layout, self._shape, self._strides, self._dtype, self._fieldName))
+
+    def __eq__(self, other):
+        selfTuple = (self.shape, self.strides, self.name, self.dtype)
+        otherTuple = (other.shape, other.strides, other.name, other.dtype)
+        return selfTuple == otherTuple
+
+    PREFIX = "f"
+    STRIDE_PREFIX = PREFIX + "stride_"
+    SHAPE_PREFIX = PREFIX + "shape_"
+    STRIDE_DTYPE = "const int *"
+    SHAPE_DTYPE = "const int *"
+
+    class Access(sp.Symbol):
+        def __new__(cls, name, *args, **kwargs):
+            obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
+            return obj
+
+        def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None):
+            fieldName = field.name
+            offsetsAndIndex = chain(offsets, idx) if idx is not None else offsets
+            constantOffsets = not any([isinstance(o, sp.Basic) for o in offsetsAndIndex])
+
+            if not idx:
+                idx = tuple([0] * field.indexDimensions)
+
+            if constantOffsets:
+                offsetName = offsetToDirectionString(offsets)
+
+                if field.indexDimensions == 0:
+                    obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName)
+                elif field.indexDimensions == 1:
+                    obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName + "^" + str(idx[0]))
+                else:
+                    idxStr = ",".join([str(e) for e in idx])
+                    obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName + "^" + idxStr)
+
+            else:
+                offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex))))
+                obj = super(Field.Access, self).__xnew__(self, fieldName + "_" + offsetName)
+
+            obj._field = field
+            obj._offsets = []
+            for o in offsets:
+                if isinstance(o, sp.Basic):
+                    obj._offsets.append(o)
+                else:
+                    obj._offsets.append(int(o))
+            obj._offsetName = offsetName
+            obj._index = idx
+
+            return obj
+
+        __xnew__ = staticmethod(__new_stage2__)
+        __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+        def __call__(self, *idx):
+            if self._index != tuple([0]*self.field.indexDimensions):
+                print(self._index, tuple([0]*self.field.indexDimensions))
+                raise ValueError("Indexing an already indexed Field.Access")
+
+            idx = tuple(idx)
+            if len(idx) != self.field.indexDimensions and idx != (0,):
+                raise ValueError("Wrong number of indices: "
+                                 "Got %d, expected %d" % (len(idx), self.field.indexDimensions))
+            return Field.Access(self.field, self._offsets, idx)
+
+        @property
+        def field(self):
+            return self._field
+
+        @property
+        def offsets(self):
+            return self._offsets
+
+        @property
+        def requiredGhostLayers(self):
+            return int(np.max(np.abs(self._offsets)))
+
+        @property
+        def nrOfCoordinates(self):
+            return len(self._offsets)
+
+        @property
+        def offsetName(self):
+            return self._offsetName
+
+        @property
+        def index(self):
+            return self._index
+
+        def _hashable_content(self):
+            superClassContents = list(super(Field.Access, self)._hashable_content())
+            t = tuple([*superClassContents, hash(self._field), self._index] + self._offsets)
+            return t
diff --git a/finitedifferences.py b/finitedifferences.py
index 533b4b289..80357943d 100644
--- a/finitedifferences.py
+++ b/finitedifferences.py
@@ -1,6 +1,7 @@
-import sympy as sp
 import numpy as np
-from lbmpy.generator import Field
+import sympy as sp
+
+from pystencils.generator import Field
 
 
 def __upDownOffsets(d, dim):
diff --git a/generator.py b/generator.py
new file mode 100644
index 000000000..5a8dcec4c
--- /dev/null
+++ b/generator.py
@@ -0,0 +1,893 @@
+from collections import defaultdict
+import cgen as c
+import sympy as sp
+from sympy.logic.boolalg import Boolean
+from sympy.utilities.codegen import CCodePrinter
+from sympy.tensor import IndexedBase, Indexed
+from pystencils.field import Field, offsetComponentToDirectionString
+from pystencils.typedsymbol import TypedSymbol
+
+COORDINATE_LOOP_COUNTER_NAME = "ctr"
+FIELD_PTR_PREFIX = Field.PREFIX + "d_"
+
+
+# --------------------------------------- Helper Functions -------------------------------------------------------------
+
+
+class CodePrinter(CCodePrinter):
+    def _print_Pow(self, expr):
+        if expr.exp.is_integer and expr.exp.is_number and expr.exp > 0:
+            return '(' + '*'.join(["(" + self._print(expr.base) + ")"] * expr.exp) + ')'
+        else:
+            return super(CodePrinter, self)._print_Pow(expr)
+
+    def _print_Rational(self, expr):
+        return str(expr.evalf().num)
+
+    def _print_Equality(self, expr):
+        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
+
+    def _print_Piecewise(self, expr):
+        result = super(CodePrinter, self)._print_Piecewise(expr)
+        return result.replace("\n", "")
+
+codePrinter = CodePrinter()
+
+
+class MyPOD(c.Declarator):
+    def __init__(self, dtype, name):
+        self.dtype = dtype
+        self.name = name
+
+    def get_decl_pair(self):
+        return [self.dtype], self.name
+
+
+def getNextParentOfType(node, parentType):
+    parent = node.parent
+    while parent is not None:
+        if isinstance(parent, parentType):
+            return parent
+        parent = parent.parent
+    return None
+
+
+# --------------------------------------- AST Nodes  -------------------------------------------------------------------
+
+
+class Node:
+    def __init__(self, parent=None):
+        self.parent = parent
+
+    def args(self):
+        return []
+
+    def atoms(self, argType):
+        result = set()
+        for arg in self.args:
+            if isinstance(arg, argType):
+                result.add(arg)
+            result.update(arg.atoms(argType))
+        return result
+
+
+class DebugNode(Node):
+    def __init__(self, code, symbolsRead=[]):
+        self._code = code
+        self._symbolsRead = set(symbolsRead)
+
+    @property
+    def args(self):
+        return []
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def symbolsRead(self):
+        return self._symbolsRead
+
+    def generateC(self):
+        return c.LiteralLines(self._code)
+
+
+class PrintNode(DebugNode):
+    def __init__(self, symbolToPrint):
+        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name)
+        super(PrintNode, self).__init__(code, [symbolToPrint])
+
+
+class KernelFunction(Node):
+
+    class Argument:
+        def __init__(self, name, dtype):
+            self.name = name
+            self.dtype = dtype
+            self.isFieldPtrArgument = False
+            self.isFieldShapeArgument = False
+            self.isFieldStrideArgument = False
+            self.isFieldArgument = False
+            self.fieldName = ""
+            self.coordinate = None
+
+            if name.startswith(FIELD_PTR_PREFIX):
+                self.isFieldPtrArgument = True
+                self.isFieldArgument = True
+                self.fieldName = name[len(FIELD_PTR_PREFIX):]
+            elif name.startswith(Field.SHAPE_PREFIX):
+                self.isFieldShapeArgument = True
+                self.isFieldArgument = True
+                self.fieldName = name[len(Field.SHAPE_PREFIX):]
+            elif name.startswith(Field.STRIDE_PREFIX):
+                self.isFieldStrideArgument = True
+                self.isFieldArgument = True
+                self.fieldName = name[len(Field.STRIDE_PREFIX):]
+
+    def __init__(self, body, functionName="kernel"):
+        super(KernelFunction, self).__init__()
+        self._body = body
+        self._parameters = None
+        self._functionName = functionName
+        self._body.parent = self
+        self.variablesToIgnore = set()
+        self.qualifierPrefix = ""
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    @property
+    def parameters(self):
+        self._updateArguments()
+        return self._parameters
+
+    @property
+    def body(self):
+        return self._body
+
+    @property
+    def args(self):
+        return [self._body]
+
+    @property
+    def functionName(self):
+        return self._functionName
+
+    def _updateArguments(self):
+        undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore
+        self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
+        self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
+                                             l.isFieldStrideArgument, l.name),
+                              reverse=True)
+
+    def generateC(self):
+        self._updateArguments()
+        functionArguments = [MyPOD(s.dtype, s.name) for s in self._parameters]
+        functionPOD = MyPOD(self.qualifierPrefix + "void", self._functionName, )
+        funcDeclaration = c.FunctionDeclaration(functionPOD, functionArguments)
+        return c.FunctionBody(funcDeclaration, self._body.generateC())
+
+
+class Block(Node):
+    def __init__(self, listOfNodes):
+        super(Node, self).__init__()
+        self._nodes = listOfNodes
+        for n in self._nodes:
+            n.parent = self
+
+    @property
+    def args(self):
+        return self._nodes
+
+    def insertFront(self, node):
+        node.parent = self
+        self._nodes.insert(0, node)
+
+    def append(self, node):
+        node.parent = self
+        self._nodes.append(node)
+
+    def generateC(self):
+        return c.Block([e.generateC() for e in self.args])
+
+    def takeChildNodes(self):
+        tmp = self._nodes
+        self._nodes = []
+        return tmp
+
+    def replace(self, child, replacements):
+        idx = self._nodes.index(child)
+        del self._nodes[idx]
+        if type(replacements) is list:
+            for e in replacements:
+                e.parent = self
+            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
+        else:
+            replacements.parent = self
+            self._nodes.insert(idx, replacements)
+
+    @property
+    def symbolsDefined(self):
+        result = set()
+        for a in self.args:
+            result.update(a.symbolsDefined)
+        return result
+
+    @property
+    def symbolsRead(self):
+        result = set()
+        for a in self.args:
+            result.update(a.symbolsRead)
+        return result
+
+
+class PragmaBlock(Block):
+    def __init__(self, pragmaLine, listOfNodes):
+        super(PragmaBlock, self).__init__(listOfNodes)
+        self._pragmaLine = pragmaLine
+
+    def generateC(self):
+        class PragmaGenerable(c.Generable):
+            def __init__(self, line, block):
+                self._line = line
+                self._block = block
+
+            def generate(self):
+                yield self._line
+                for e in self._block.generate():
+                    yield e
+
+        return PragmaGenerable(self._pragmaLine, super(PragmaBlock, self).generateC())
+
+
+class LoopOverCoordinate(Node):
+
+    def __init__(self, body, coordinateToLoopOver, shape, increment=1, ghostLayers=1,
+                 isInnermostLoop=False, isOutermostLoop=False):
+        self._body = body
+        self._coordinateToLoopOver = coordinateToLoopOver
+        self._shape = shape
+        self._increment = increment
+        self._ghostLayers = ghostLayers
+        self._body.parent = self
+        self.prefixLines = []
+        self._isInnermostLoop = isInnermostLoop
+        self._isOutermostLoop = isOutermostLoop
+
+    def newLoopWithDifferentBody(self, newBody):
+        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._shape, self._increment,
+                                    self._ghostLayers, self._isInnermostLoop, self._isOutermostLoop)
+        result.prefixLines = self.prefixLines
+        return result
+
+    @property
+    def args(self):
+        result = [self._body]
+        limit = self._shape[self._coordinateToLoopOver]
+        if isinstance(limit, Node) or isinstance(limit, sp.Basic):
+            result.append(limit)
+        return result
+
+    @property
+    def body(self):
+        return self._body
+
+    @property
+    def loopCounterName(self):
+        return "%s_%s" % (COORDINATE_LOOP_COUNTER_NAME, self._coordinateToLoopOver)
+
+    @property
+    def coordinateToLoopOver(self):
+        return self._coordinateToLoopOver
+
+    @property
+    def symbolsDefined(self):
+        result = self._body.symbolsDefined
+        result.add(self.loopCounterSymbol)
+        return result
+
+    @property
+    def loopCounterSymbol(self):
+        return TypedSymbol(self.loopCounterName, "int")
+
+    @property
+    def symbolsRead(self):
+        result = self._body.symbolsRead
+        limit = self._shape[self._coordinateToLoopOver]
+        if isinstance(limit, sp.Basic):
+            result.update(limit.atoms(sp.Symbol))
+        return result
+
+    @property
+    def isOutermostLoop(self):
+        return self._isOutermostLoop
+
+    @property
+    def isInnermostLoop(self):
+        return self._isInnermostLoop
+
+    @property
+    def coordinateToLoopOver(self):
+        return self._coordinateToLoopOver
+
+    @property
+    def iterationRegionWithGhostLayer(self):
+        return self._shape[self._coordinateToLoopOver]
+
+    def generateC(self):
+        coord = self._coordinateToLoopOver
+        end = self._shape[coord] - self._ghostLayers
+
+        counterVar = self.loopCounterName
+
+        class LoopWithOptionalPrefix(c.CustomLoop):
+            def __init__(self, intro_line, body, prefixLines=[]):
+                super(LoopWithOptionalPrefix, self).__init__(intro_line, body)
+                self.prefixLines = prefixLines
+
+            def generate(self):
+                for l in self.prefixLines:
+                    yield l
+
+                for e in super(LoopWithOptionalPrefix, self).generate():
+                    yield e
+
+        start = "int %s = %d" % (counterVar, self._ghostLayers)
+        condition = "%s < %s" % (counterVar, codePrinter.doprint(end))
+        update = "++%s" % (counterVar,)
+        loopStr = "for (%s; %s; %s)" % (start, condition, update)
+        return LoopWithOptionalPrefix(loopStr, self._body.generateC(), prefixLines=self.prefixLines)
+
+
+class SympyAssignment(Node):
+
+    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
+        self._lhsSymbol = lhsSymbol
+        self.rhs = rhsTerm
+        self._isDeclaration = True
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase):
+            self._isDeclaration = False
+        self._isConst = isConst
+
+    @property
+    def lhs(self):
+        return self._lhsSymbol
+
+    @lhs.setter
+    def lhs(self, newValue):
+        self._lhsSymbol = newValue
+        self._isDeclaration = True
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, Indexed):
+            self._isDeclaration = False
+
+    @property
+    def args(self):
+        return [self._lhsSymbol, self.rhs]
+
+    @property
+    def symbolsDefined(self):
+        if not self._isDeclaration:
+            return set()
+        return set([self._lhsSymbol])
+
+    @property
+    def symbolsRead(self):
+        result = self.rhs.atoms(sp.Symbol)
+        result.update(self._lhsSymbol.atoms(sp.Symbol))
+        return result
+
+    @property
+    def isConst(self):
+        return self._isConst
+
+    def __repr__(self):
+        return repr(self.lhs) + " = " + repr(self.rhs)
+
+    def generateC(self):
+        dtype = ""
+        if hasattr(self._lhsSymbol, 'dtype') and self._isDeclaration:
+            if self._isConst:
+                dtype = "const " + self._lhsSymbol.dtype + " "
+            else:
+                dtype = self._lhsSymbol.dtype + " "
+
+        return c.Assign(dtype + codePrinter.doprint(self._lhsSymbol),
+                        codePrinter.doprint(self.rhs))
+
+
+class CustomCppCode(Node):
+    def __init__(self, code, symbolsRead, symbolsDefined):
+        self._code = "\n" + code
+        self._symbolsRead = set(symbolsRead)
+        self._symbolsDefined = set(symbolsDefined)
+
+    @property
+    def args(self):
+        return []
+
+    @property
+    def symbolsDefined(self):
+        return self._symbolsDefined
+
+    @property
+    def symbolsRead(self):
+        return self._symbolsRead
+
+    def generateC(self):
+        return c.LiteralLines(self._code)
+
+
+class TemporaryArrayDefinition(Node):
+    def __init__(self, typedSymbol, size):
+        self._symbol = typedSymbol
+        self._size = size
+
+    @property
+    def symbolsDefined(self):
+        return set([self._symbol])
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    def generateC(self):
+        return c.Assign(self._symbol.dtype + " * " + codePrinter.doprint(self._symbol),
+                        "new %s[%s]" % (self._symbol.dtype, codePrinter.doprint(self._size)))
+
+    @property
+    def args(self):
+        return [self._symbol]
+
+
+class TemporaryArrayDelete(Node):
+    def __init__(self, typedSymbol):
+        self._symbol = typedSymbol
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def symbolsRead(self):
+        return set()
+
+    def generateC(self):
+        return c.Statement("delete [] %s" % (codePrinter.doprint(self._symbol),))
+
+    @property
+    def args(self):
+        return []
+
+
+# --------------------------------------- Factory Functions ------------------------------------------------------------
+
+
+def getOptimalLoopOrdering(fields):
+    assert len(fields) > 0
+    refField = next(iter(fields))
+    for field in fields:
+        if field.spatialDimensions != refField.spatialDimensions:
+            raise ValueError("All fields have to have the same number of spatial dimensions")
+
+    layouts = set([field.layout for field in fields])
+    if len(layouts) > 1:
+        raise ValueError("Due to different layout of the fields no optimal loop ordering exists")
+    layout = list(layouts)[0]
+    return list(reversed(layout))
+
+
+def makeLoopOverDomain(body, functionName):
+    """
+    :param body: list of nodes
+    :param functionName: name of generated C function
+    :return: LoopOverCoordinate instance with nested loops, ordered according to field layouts
+    """
+    # find correct ordering by inspecting participating FieldAccesses
+    fieldAccesses = body.atoms(Field.Access)
+    fieldList = [e.field for e in fieldAccesses]
+    fields = set(fieldList)
+    loopOrder = getOptimalLoopOrdering(fields)
+
+    # find number of required ghost layers
+    requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
+
+    shapes = set([f.spatialShape for f in fields])
+
+    if len(shapes) > 1:
+        nrOfFixedSizedFields = 0
+        for shape in shapes:
+            if not isinstance(shape[0], sp.Basic):
+                nrOfFixedSizedFields += 1
+        assert nrOfFixedSizedFields <= 1, "Differently sized field accesses in loop body: " + str(shapes)
+    shape = list(shapes)[0]
+
+    currentBody = body
+    lastLoop = None
+    for i, loopCoordinate in enumerate(loopOrder):
+        newLoop = LoopOverCoordinate(currentBody, loopCoordinate, shape, 1, requiredGhostLayers,
+                                     isInnermostLoop=(i == 0), isOutermostLoop=(i == len(loopOrder) - 1))
+        lastLoop = newLoop
+        currentBody = Block([lastLoop])
+    return KernelFunction(currentBody, functionName)
+
+
+# --------------------------------------- Transformations --------------------------------------------------------------
+
+def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
+    field = fieldAccess.field
+
+    offset = 0
+    name = ""
+    listToHash = []
+    for coordinateId, coordinateValue in coordinates.items():
+        offset += field.strides[coordinateId] * coordinateValue
+
+        if coordinateId < field.spatialDimensions:
+            offset += field.strides[coordinateId] * fieldAccess.offsets[coordinateId]
+            if type(fieldAccess.offsets[coordinateId]) is int:
+                offsetComp = offsetComponentToDirectionString(coordinateId, fieldAccess.offsets[coordinateId])
+                name += "_"
+                name += offsetComp if offsetComp else "C"
+            else:
+                listToHash.append(fieldAccess.offsets[coordinateId])
+        else:
+            if type(coordinateValue) is int:
+                name += "_%d" % (coordinateValue,)
+            else:
+                listToHash.append(coordinateValue)
+
+    if len(listToHash) > 0:
+        name += "%0.6X" % (abs(hash(tuple(listToHash))))
+
+    newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
+    return newPtr, offset
+
+
+def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
+    """
+    Allowed specifications:
+    "spatialInner<int>" spatialInner0 is the innermost loop coordinate, spatialInner1 the loop enclosing the innermost
+    "spatialOuter<int>" spatialOuter0 is the outermost loop
+    "index<int>": index coordinate
+    "<int>": specifying directly the coordinate
+    :param basePointerSpecification: nested list with above specifications
+    :param loopOrder: list with ordering of loops from inner to outer
+    :param field:
+    :return:
+    """
+    result = []
+    specifiedCoordinates = set()
+    for specGroup in basePointerSpecification:
+        newGroup = []
+
+        def addNewElement(i):
+            if i >= field.spatialDimensions + field.indexDimensions:
+                raise ValueError("Coordinate %d does not exist" % (i,))
+            newGroup.append(i)
+            if i in specifiedCoordinates:
+                raise ValueError("Coordinate %d specified two times" % (i,))
+            specifiedCoordinates.add(i)
+
+        for element in specGroup:
+            if type(element) is int:
+                addNewElement(element)
+            elif element.startswith("spatial"):
+                element = element[len("spatial"):]
+                if element.startswith("Inner"):
+                    index = int(element[len("Inner"):])
+                    addNewElement(loopOrder[index])
+                elif element.startswith("Outer"):
+                    index = int(element[len("Outer"):])
+                    addNewElement(loopOrder[-index])
+                elif element == "all":
+                    for i in range(field.spatialDimensions):
+                        addNewElement(i)
+                else:
+                    raise ValueError("Could not parse " + element)
+            elif element.startswith("index"):
+                index = int(element[len("index"):])
+                addNewElement(field.spatialDimensions + index)
+            else:
+                raise ValueError("Unknown specification %s" % (element,))
+
+        result.append(newGroup)
+
+    allCoordinates = set(range(field.spatialDimensions + field.indexDimensions))
+    rest = allCoordinates - specifiedCoordinates
+    if rest:
+        result.append(list(rest))
+    return result
+
+
+def getLoopHierarchy(block):
+    result = []
+    node = block
+    while node is not None:
+        node = getNextParentOfType(node, LoopOverCoordinate)
+        if node:
+            result.append(node.coordinateToLoopOver)
+    return result
+
+
+def resolveFieldAccesses(ast, fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
+    """Substitutes FieldAccess nodes by array indexing"""
+
+    def visitSympyExpr(expr, enclosingBlock):
+        if isinstance(expr, Field.Access):
+            fieldAccess = expr
+            field = fieldAccess.field
+            if field.name in fieldToBasePointerInfo:
+                basePointerInfo = fieldToBasePointerInfo[field.name]
+            else:
+                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
+
+            dtype = "%s * __restrict__" % field.dtype
+            if field.readOnly:
+                dtype = "const " + dtype
+
+            fieldPtr = TypedSymbol("%s%s" % (FIELD_PTR_PREFIX, field.name), dtype)
+
+            lastPointer = fieldPtr
+
+            def createCoordinateDict(group):
+                coordDict = {}
+                for e in group:
+                    if e < field.spatialDimensions:
+                        if field.name in fieldToFixedCoordinates:
+                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
+                        else:
+                            coordDict[e] = TypedSymbol("%s_%d" % (COORDINATE_LOOP_COUNTER_NAME, e), "int")
+                    else:
+                        coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
+                return coordDict
+
+            for group in reversed(basePointerInfo[1:]):
+                coordDict = createCoordinateDict(group)
+                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
+                if newPtr not in enclosingBlock.symbolsDefined:
+                    enclosingBlock.insertFront(SympyAssignment(newPtr, lastPointer + offset, isConst=False))
+                lastPointer = newPtr
+
+            _, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]), lastPointer)
+            baseArr = IndexedBase(lastPointer, shape=(1,))
+            return baseArr[offset]
+        else:
+            newArgs = [visitSympyExpr(e, enclosingBlock) for e in expr.args]
+            kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
+            return expr.func(*newArgs, **kwargs) if newArgs else expr
+
+    def visitNode(subAst):
+        if isinstance(subAst, SympyAssignment):
+            enclosingBlock = subAst.parent
+            assert type(enclosingBlock) is Block
+            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock)
+            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock)
+        else:
+            for i, a in enumerate(subAst.args):
+                visitNode(a)
+
+    return visitNode(ast)
+
+
+def moveConstantsBeforeLoop(ast):
+
+    def findBlockToMoveTo(node):
+        """Traverses parents of node as long as the symbols are independent and returns a (parent) block
+        the assignment can be safely moved to
+        :param node: SympyAssignment inside a Block"""
+        assert isinstance(node, SympyAssignment)
+        assert isinstance(node.parent, Block)
+
+        lastBlock = node.parent
+        element = node.parent
+        while element:
+            if isinstance(element, Block):
+                lastBlock = element
+            if node.symbolsRead.intersection(element.symbolsDefined):
+                break
+            element = element.parent
+        return lastBlock
+
+    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
+        for arg in targetBlock.args:
+            if type(arg) is not SympyAssignment:
+                continue
+            if arg.lhs == assignment.lhs:
+                return arg
+        return None
+
+    for block in ast.atoms(Block):
+        children = block.takeChildNodes()
+        for child in children:
+            if not isinstance(child, SympyAssignment):
+                block.append(child)
+            else:
+                target = findBlockToMoveTo(child)
+                if target == block:     # movement not possible
+                    target.append(child)
+                else:
+                    existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
+                    if not existingAssignment:
+                        target.insertFront(child)
+                    else:
+                        assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
+
+
+def splitInnerLoop(ast, symbolGroups):
+    allLoops = ast.atoms(LoopOverCoordinate)
+    innerLoop = [l for l in allLoops if l.isInnermostLoop]
+    assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
+    innerLoop = innerLoop[0]
+    assert type(innerLoop.body) is Block
+    outerLoop = [l for l in allLoops if l.isOutermostLoop]
+    assert len(outerLoop) == 1, "Error in AST, multiple outermost loops."
+    outerLoop = outerLoop[0]
+
+    symbolsWithTemporaryArray = dict()
+
+    assignmentMap = {a.lhs: a for a in innerLoop.body.args}
+
+    assignmentGroups = []
+    for symbolGroup in symbolGroups:
+        # get all dependent symbols
+        symbolsToProcess = list(symbolGroup)
+        symbolsResolved = set()
+        while symbolsToProcess:
+            s = symbolsToProcess.pop()
+            if s in symbolsResolved:
+                continue
+
+            if s in assignmentMap:  # if there is no assignment inside the loop body it is independent already
+                for newSymbol in assignmentMap[s].rhs.atoms(sp.Symbol):
+                    if type(newSymbol) is not Field.Access and newSymbol not in symbolsWithTemporaryArray:
+                        symbolsToProcess.append(newSymbol)
+            symbolsResolved.add(s)
+
+        for symbol in symbolGroup:
+            if type(symbol) is not Field.Access:
+                assert type(symbol) is TypedSymbol
+                symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]
+
+        assignmentGroup = []
+        for assignment in innerLoop.body.args:
+            if assignment.lhs in symbolsResolved:
+                newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
+                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
+                    newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
+                else:
+                    newLhs = assignment.lhs
+                assignmentGroup.append(SympyAssignment(newLhs, newRhs))
+        assignmentGroups.append(assignmentGroup)
+
+    newLoops = [innerLoop.newLoopWithDifferentBody(Block(group)) for group in assignmentGroups]
+    innerLoop.parent.replace(innerLoop, newLoops)
+
+    for tmpArray in symbolsWithTemporaryArray:
+        outerLoop.parent.insertFront(TemporaryArrayDefinition(tmpArray, innerLoop.iterationRegionWithGhostLayer))
+        outerLoop.parent.append(TemporaryArrayDelete(tmpArray))
+
+
+# ------------------------------------- Main ---------------------------------------------------------------------------
+
+
+def extractCommonSubexpressions(equations):
+    """Uses sympy to find common subexpressions in equations and returns
+    them in a topologically sorted order, ready for evaluation"""
+    replacements, newEq = sp.cse(equations)
+    replacementEqs = [sp.Eq(*r) for r in replacements]
+    equations = replacementEqs + newEq
+    topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equations])
+    equations = [sp.Eq(*a) for a in topologicallySortedPairs]
+    return equations
+
+
+def addOpenMP(ast):
+    assert type(ast) is KernelFunction
+    body = ast.body
+    wrapperBlock = PragmaBlock('#pragma omp parallel', body.takeChildNodes())
+    body.append(wrapperBlock)
+
+    outerLoops = [l for l in body.atoms(LoopOverCoordinate) if l.isOutermostLoop]
+    assert outerLoops, "No outer loop found"
+    assert len(outerLoops) <= 1, "More than one outer loop found. Which one should be parallelized?"
+    outerLoops[0].prefixLines.append("#pragma omp for schedule(static)")
+
+
+def typeAllEquations(eqs, typeForSymbol):
+    fieldsWritten = set()
+    fieldsRead = set()
+
+    def processRhs(term):
+        """Replaces Symbols by:
+            - TypedSymbol if symbol is not a field access
+        """
+        if isinstance(term, Field.Access):
+            fieldsRead.add(term.field)
+            return term
+        elif isinstance(term, sp.Symbol):
+            return TypedSymbol(term.name, typeForSymbol[term.name])
+        else:
+            newArgs = [processRhs(arg) for arg in term.args]
+            return term.func(*newArgs) if newArgs else term
+
+    def processLhs(term):
+        """Replaces symbol by TypedSymbol and adds field to fieldsWriten"""
+        if isinstance(term, Field.Access):
+            fieldsWritten.add(term.field)
+            return term
+        elif isinstance(term, sp.Symbol):
+            return TypedSymbol(term.name, typeForSymbol[term.name])
+        else:
+            assert False, "Expected a symbol as left-hand-side"
+
+    typedEquations = []
+    for eq in eqs:
+        if isinstance(eq, sp.Eq):
+            newLhs = processLhs(eq.lhs)
+            newRhs = processRhs(eq.rhs)
+            typedEquations.append(SympyAssignment(newLhs, newRhs))
+        else:
+            assert isinstance(eq, Node), "Only equations and ast nodes are allowed in input"
+            typedEquations.append(eq)
+
+    typedEquations = typedEquations
+
+    return fieldsRead, fieldsWritten, typedEquations
+
+
+def typingFromSympyInspection(eqs, defaultType="double"):
+    result = defaultdict(lambda: defaultType)
+    for eq in eqs:
+        if isinstance(eq.rhs, Boolean):
+            result[eq.lhs.name] = "bool"
+    return result
+
+
+def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, splitGroups=[]):
+    if not typeForSymbol:
+        typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
+
+    def typeSymbol(term):
+        if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
+            return term
+        elif isinstance(term, sp.Symbol):
+            return TypedSymbol(term.name, typeForSymbol[term.name])
+        else:
+            raise ValueError("Term has to be field access or symbol")
+
+    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
+    allFields = fieldsRead.union(fieldsWritten)
+
+    for field in allFields:
+        field.setReadOnly(False)
+    for field in fieldsRead - fieldsWritten:
+        field.setReadOnly()
+
+    body = Block(assignments)
+    code = makeLoopOverDomain(body, functionName)
+
+    if splitGroups:
+        typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
+        splitInnerLoop(code, typedSplitGroups)
+
+    loopOrder = getOptimalLoopOrdering(allFields)
+
+    basePointerInfo = [['spatialInner0'], ['spatialInner1']]
+    basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, loopOrder, f) for f in allFields}
+
+    resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos)
+    moveConstantsBeforeLoop(code)
+    addOpenMP(code)
+
+    return code
+
+
+if __name__ == "__main__":
+    f = Field.createGeneric('f', 3, indexDimensions=1)
+    pointerSpec = [['spatialInner0']]
+    parseBasePointerInfo(pointerSpec, [0, 1, 2], f)
\ No newline at end of file
diff --git a/jit.py b/jit.py
new file mode 100644
index 000000000..fdaf9800e
--- /dev/null
+++ b/jit.py
@@ -0,0 +1,149 @@
+import os
+import subprocess
+from ctypes import cdll, c_double, c_float, sizeof
+from tempfile import TemporaryDirectory
+
+import numpy as np
+
+
+CONFIG_GCC = {
+    'compiler': 'g++',
+    'flags': '-Ofast -DNDEBUG -fPIC -shared -march=native -fopenmp',
+}
+CONFIG_INTEL = {
+    'compiler': '/software/intel/2017/bin/icpc',
+    'flags': '-Ofast -DNDEBUG -fPIC -shared -march=native -fopenmp -Wl,-rpath=/software/intel/2017/lib/intel64',
+    'env': {
+        'INTEL_LICENSE_FILE': '1713@license4.rrze.uni-erlangen.de',
+        'LM_PROJECT': 'iwia',
+    }
+}
+CONFIG_CLANG = {
+    'compiler': 'clang++',
+    'flags': '-Ofast -DNDEBUG -fPIC -shared -march=native -fopenmp',
+}
+CONFIG = CONFIG_INTEL
+
+
+def ctypeFromString(typename, includePointers=True):
+    import ctypes as ct
+
+    typename = typename.replace("*", " * ")
+    typeComponents = typename.split()
+
+    basicTypeMap = {
+        'double': ct.c_double,
+        'float': ct.c_float,
+        'int': ct.c_int,
+        'long': ct.c_long,
+    }
+
+    resultType = None
+    for typeComponent in typeComponents:
+        typeComponent = typeComponent.strip()
+        if typeComponent == "const" or typeComponent == "restrict" or typeComponent == "volatile":
+            continue
+        if typeComponent in basicTypeMap:
+            resultType = basicTypeMap[typeComponent]
+        elif typeComponent == "*" and includePointers:
+            assert resultType is not None
+            resultType = ct.POINTER(resultType)
+
+    return resultType
+
+
+def ctypeFromNumpyType(numpyType):
+    typeMap = {
+        np.dtype('float64'): c_double,
+        np.dtype('float32'): c_float,
+    }
+    return typeMap[numpyType]
+
+
+def compileAndLoad(kernelFunctionNode):
+    with TemporaryDirectory() as tmpDir:
+        srcFile = os.path.join(tmpDir, 'source.cpp')
+        with open(srcFile, 'w') as sourceFile:
+            print('#include <iostream>', file=sourceFile)
+            print("#include <cmath>", file=sourceFile)
+            print('extern "C" { ', file=sourceFile)
+            print(kernelFunctionNode.generateC(), file=sourceFile)
+            print('}', file=sourceFile)
+
+        compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
+        libFile = os.path.join(tmpDir, "jit.so")
+        compilerCmd += [srcFile, '-o', libFile]
+        configEnv = CONFIG['env'] if 'env' in CONFIG else {}
+        env = os.environ.copy()
+        env.update(configEnv)
+        subprocess.call(compilerCmd, env=env)
+
+        showAssembly = False
+        if showAssembly:
+            assemblyFile = os.path.join(tmpDir, "assembly.s")
+            compilerCmd = [CONFIG['compiler'], '-S', '-o', assemblyFile, srcFile] + CONFIG['flags'].split()
+            subprocess.call(compilerCmd, env=env)
+            assembly = open(assemblyFile, 'r').read()
+            kernelFunctionNode.assembly = assembly
+        loadedJitLib = cdll.LoadLibrary(libFile)
+
+    return loadedJitLib
+
+
+def buildCTypeArgumentList(kernelFunctionNode, argumentDict):
+    ctArguments = []
+    for arg in kernelFunctionNode.parameters:
+        if arg.isFieldArgument:
+            field = argumentDict[arg.fieldName]
+            if arg.isFieldPtrArgument:
+                ctArguments.append(field.ctypes.data_as(ctypeFromString(arg.dtype)))
+            elif arg.isFieldShapeArgument:
+                dataType = ctypeFromString(arg.dtype, includePointers=False)
+                ctArguments.append(field.ctypes.shape_as(dataType))
+            elif arg.isFieldStrideArgument:
+                dataType = ctypeFromString(arg.dtype, includePointers=False)
+                baseFieldType = ctypeFromNumpyType(field.dtype)
+                strides = field.ctypes.strides_as(dataType)
+                for i in range(len(field.shape)):
+                    assert strides[i] % sizeof(baseFieldType) == 0
+                    strides[i] //= sizeof(baseFieldType)
+                ctArguments.append(strides)
+            else:
+                assert False
+        else:
+            param = argumentDict[arg.name]
+            expectedType = ctypeFromString(arg.dtype)
+            ctArguments.append(expectedType(param))
+    return ctArguments
+
+
+def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict):
+    func = compileAndLoad(kernelFunctionNode)[kernelFunctionNode.functionName]
+    func.restype = None
+
+    def wrapper(**kwargs):
+        from copy import copy
+        fullArguments = copy(argumentDict)
+        fullArguments.update(kwargs)
+        args = buildCTypeArgumentList(kernelFunctionNode, fullArguments)
+        func(*args)
+    return wrapper
+
+
+def makePythonFunction(kernelFunctionNode, argumentDict={}):
+    # build up list of CType arguments
+    try:
+        args = buildCTypeArgumentList(kernelFunctionNode, argumentDict)
+    except KeyError:
+        # not all parameters specified yet
+        return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict)
+    func = compileAndLoad(kernelFunctionNode)[kernelFunctionNode.functionName]
+    func.restype = None
+    return lambda: func(*args)
+
+
+
+
+
+
+
diff --git a/typedsymbol.py b/typedsymbol.py
new file mode 100644
index 000000000..675c39093
--- /dev/null
+++ b/typedsymbol.py
@@ -0,0 +1,26 @@
+import sympy as sp
+from sympy.core.cache import cacheit
+
+
+class TypedSymbol(sp.Symbol):
+
+    def __new__(cls, name, *args, **kwds):
+        obj = TypedSymbol.__xnew_cached_(cls, name, *args, **kwds)
+        return obj
+
+    def __new_stage2__(cls, name, dtype):
+        obj = super(TypedSymbol, cls).__xnew__(cls, name)
+        obj._dtype = dtype
+        return obj
+
+    __xnew__ = staticmethod(__new_stage2__)
+    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def _hashable_content(self):
+        superClassContents = list(super(TypedSymbol, self)._hashable_content())
+        t = tuple([*superClassContents, hash(self._dtype)])
+        return t
-- 
GitLab