From 860cf788a3d0970aa1faa2661cbd21a1d2a2fba1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Wed, 11 Oct 2017 15:44:35 +0200
Subject: [PATCH] Jan's rest of Master Thesis and followup Work:

Added LLVM: CodePrinter and a compiler
Updated data_types
Added tests
Added jupyter notebooks
Fixed bugs
Restructured transformation functions
---
 astnodes.py                                   | 213 ++++++------------
 backends/dot.py                               |  11 +-
 data_types.py                                 | 116 ++++++++--
 field.py                                      |   4 +
 llvm/__init__.py                              |   5 +-
 llvm/jit.py                                   |  81 -------
 llvm/kernelcreation.py                        |  87 ++++++-
 {backends => llvm}/llvm.py                    |  67 +++---
 llvm/llvmjit.py                               | 182 +++++++++++++++
 transformations/__init__.py                   |   2 +
 transformations/stage2.py                     | 159 +++++++++++++
 .../transformations.py                        |   1 -
 12 files changed, 647 insertions(+), 281 deletions(-)
 delete mode 100644 llvm/jit.py
 rename {backends => llvm}/llvm.py (81%)
 create mode 100644 llvm/llvmjit.py
 create mode 100644 transformations/__init__.py
 create mode 100644 transformations/stage2.py
 rename transformations.py => transformations/transformations.py (99%)

diff --git a/astnodes.py b/astnodes.py
index 09a23075f..2255e9455 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -61,6 +61,10 @@ class Node(object):
         for a in self.args:
             a.subs(*args, **kwargs)
 
+    @property
+    def func(self):
+        return self.__class__
+
     def atoms(self, argType):
         """
         Returns a set of all children which are an instance of the given argType
@@ -224,6 +228,7 @@ class Block(Node):
     def __init__(self, listOfNodes):
         super(Node, self).__init__()
         self._nodes = listOfNodes
+        self.parent = None
         for n in self._nodes:
             n.parent = self
 
@@ -324,6 +329,17 @@ class LoopOverCoordinate(Node):
                 result.append(e)
         return result
 
+    def replace(self, child, replacement):
+        """Swap `child` for `replacement` in the first slot it equals: body, start, step, then stop."""
+        if child == self.body:
+            self.body = replacement
+        elif child == self.start:
+            self.start = replacement
+        elif child == self.step:
+            self.step = replacement
+        elif child == self.stop:
+            self.stop = replacement
+
     @property
     def symbolsDefined(self):
         return set([self.loopCounterSymbol])
@@ -372,11 +388,15 @@ class LoopOverCoordinate(Node):
         return len(self.atoms(LoopOverCoordinate)) == 0
 
     def __str__(self):
-        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
-                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))
+        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
+                                                                     self.loopCounterName, self.stop,
+                                                                     self.loopCounterName, self.step,
+                                                                     ("\t" + "\t".join(str(self.body).splitlines(True))))
 
     def __repr__(self):
-        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)
+        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
+                                                               self.loopCounterName, self.stop,
+                                                               self.loopCounterName, self.step)
 
 
 class SympyAssignment(Node):
@@ -488,141 +508,52 @@ class TemporaryMemoryFree(Node):
         return []
 
 
-# TODO implement defined & undefinedSymbols
-class Conversion(Node):
-    def __init__(self, child, dtype, parent=None):
-        super(Conversion, self).__init__(parent)
-        self._args = [child]
-        self.dtype = dtype
-
-    @property
-    def args(self):
-        """Returns all arguments/children of this node"""
-        return self._args
-
-    @args.setter
-    def args(self, value):
-        self._args = value
-
-    @property
-    def symbolsDefined(self):
-        """Set of symbols which are defined by this node. """
-        return set()
-
-    @property
-    def undefinedSymbols(self):
-        """Symbols which are use but are not defined inside this node"""
-        raise set()
-
-    def __repr__(self):
-        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
-
-# TODO Pow
-
-
-_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
-
-
-class Expr(Node):
-    def __init__(self, args, parent=None):
-        super(Expr, self).__init__(parent)
-        self._args = list(args)
-        self.dtype = None
-
-    @property
-    def args(self):
-        return self._args
-
-    @args.setter
-    def args(self, value):
-        self._args = value
-
-    def replace(self, child, replacements):
-        idx = self.args.index(child)
-        del self.args[idx]
-        if type(replacements) is list:
-            for e in replacements:
-                e.parent = self
-            self.args = self.args[:idx] + replacements + self.args[idx:]
-        else:
-            replacements.parent = self
-            self.args.insert(idx, replacements)
-
-    @property
-    def symbolsDefined(self):
-        return set()  # Todo fix for symbol analysis
-
-    @property
-    def undefinedSymbols(self):
-        return set()  # Todo fix for symbol analysis
-
-    def __repr__(self):
-        return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
-
-
-class Mul(Expr):
-    pass
-
-
-class Add(Expr):
-    pass
-
-
-class Pow(Expr):
-    pass
-
-
-class Indexed(Expr):
-    def __init__(self, args, base, parent=None):
-        super(Indexed, self).__init__(args, parent)
-        self.base = base
-        # Get dtype from label, and unpointer it
-        self.dtype = createType(base.label.dtype.baseType)
-
-    def __repr__(self):
-        return '%s[%s]' % (self.args[0], self.args[1])
-
-
-class PointerArithmetic(Expr):
-    def __init__(self, args, pointer, parent=None):
-        super(PointerArithmetic, self).__init__([args] + [pointer], parent)
-        self.pointer = pointer
-        self.offset = args
-        self.dtype = pointer.dtype
-
-    def __repr__(self):
-        return '*(%s + %s)' % (self.pointer, self.args)
-
-
-class Number(Node, sp.AtomicExpr):
-    def __init__(self, number, parent=None):
-        super(Number, self).__init__(parent)
-
-        self.dtype, self.value = get_type_from_sympy(number)
-        self._args = tuple()
-
-    @property
-    def args(self):
-        """Returns all arguments/children of this node"""
-        return self._args
-
-    @property
-    def symbolsDefined(self):
-        """Set of symbols which are defined by this node. """
-        return set()
-
-    @property
-    def undefinedSymbols(self):
-        """Symbols which are use but are not defined inside this node"""
-        raise set()
-
-    def __repr__(self):
-        return repr(self.value)
-
-    def __float__(self):
-        return float(self.value)
-
-    def __int__(self):
-        return int(self.value)
-
-
+#_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
+#
+#
+#class Expr(Node):
+#    def __init__(self, args, parent=None):
+#        super(Expr, self).__init__(parent)
+#        self._args = list(args)
+#        self.dtype = None
+#
+#    @property
+#    def args(self):
+#        return self._args
+#
+#    @args.setter
+#    def args(self, value):
+#        self._args = value
+#
+#    def replace(self, child, replacements):
+#        idx = self.args.index(child)
+#        del self.args[idx]
+#        if type(replacements) is list:
+#            for e in replacements:
+#                e.parent = self
+#            self.args = self.args[:idx] + replacements + self.args[idx:]
+#        else:
+#            replacements.parent = self
+#            self.args.insert(idx, replacements)
+#
+#    @property
+#    def symbolsDefined(self):
+#        return set()  # Todo fix for symbol analysis
+#
+#    @property
+#    def undefinedSymbols(self):
+#        return set()  # Todo fix for symbol analysis
+#
+#    def __repr__(self):
+#        return _expr_dict[self.__class__.__name__].join(repr(arg) for arg in self.args)
+#
+#
+#class PointerArithmetic(Expr):
+#    def __init__(self, args, pointer, parent=None):
+#        super(PointerArithmetic, self).__init__([args] + [pointer], parent)
+#        self.pointer = pointer
+#        self.offset = args
+#        self.dtype = pointer.dtype
+#
+#    def __repr__(self):
+#        return '*(%s + %s)' % (self.pointer, self.args)
diff --git a/backends/dot.py b/backends/dot.py
index 63ff4651f..a36c8a5a3 100644
--- a/backends/dot.py
+++ b/backends/dot.py
@@ -1,5 +1,6 @@
 from sympy.printing.printer import Printer
 from graphviz import Digraph, lang
+import graphviz
 
 
 class DotPrinter(Printer):
@@ -14,7 +15,6 @@ class DotPrinter(Printer):
         self.dot.quote_edge = lang.quote
 
     def _print_KernelFunction(self, function):
-        print(self._nodeToStrFunction(function))
         self.dot.node(self._nodeToStrFunction(function), style='filled', fillcolor='#E69F00')
         self._print(function.body)
 
@@ -75,13 +75,18 @@ def dotprint(node, view=False, short=False, full=False, **kwargs):
     :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
     :return: string in DOT format
     """
-    nodeToStrFunction = __shortened if short else lambda expr: repr(type(expr)) + repr(expr) if full else repr
+    nodeToStrFunction = repr
+    if short:
+        nodeToStrFunction = __shortened
+    elif full:
+        nodeToStrFunction = lambda expr: repr(type(expr)) + repr(expr)
     printer = DotPrinter(nodeToStrFunction, full, **kwargs)
     dot = printer.doprint(node)
     if view:
-        printer.dot.render(view=view)
+        return graphviz.Source(dot)
     return dot
 
+
 if __name__ == "__main__":
     from pystencils import Field
     import sympy as sp
diff --git a/data_types.py b/data_types.py
index 6442147cd..6f766b7aa 100644
--- a/data_types.py
+++ b/data_types.py
@@ -1,6 +1,7 @@
 import ctypes
 import sympy as sp
 import numpy as np
+import llvmlite.ir as ir
 from sympy.core.cache import cacheit
 
 from pystencils.cache import memorycache
@@ -18,6 +19,16 @@ class castFunc(sp.Function, sp.Rel):
             raise NotImplementedError()
 
 
+class pointerArithmeticFunc(sp.Function, sp.Rel):
+
+    @property
+    def canonical(self):
+        if hasattr(self.args[0], 'canonical'):
+            return self.args[0].canonical
+        else:
+            raise NotImplementedError()
+
+
 class TypedSymbol(sp.Symbol):
     def __new__(cls, *args, **kwds):
         obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
@@ -93,7 +104,10 @@ def createTypeFromString(specification):
     if basePart[0][-1] == "*":
         basePart[0] = basePart[0][:-1]
         parts.append('*')
-    baseType = BasicType(basePart[0], const)
+    try:
+        baseType = BasicType(basePart[0], const)
+    except TypeError:
+        baseType = BasicType(createTypeFromString.map[basePart[0]], const)
     currentType = baseType
     # Parse pointer parts
     for part in parts:
@@ -109,6 +123,13 @@ def createTypeFromString(specification):
         currentType = PointerType(currentType, const, restrict)
     return currentType
 
+createTypeFromString.map = {
+    'i64': np.int64,
+    'i32': np.int32,
+    'i16': np.int16,
+    'i8': np.int8,
+}
+
 
 def getBaseType(type):
     while type.baseType is not None:
@@ -145,6 +166,60 @@ toCtypes.map = {
 }
 
 
+def ctypes_from_llvm(data_type):
+    if isinstance(data_type, ir.PointerType):
+        ctype = ctypes_from_llvm(data_type.pointee)
+        if ctype is None:
+            return ctypes.c_void_p
+        else:
+            return ctypes.POINTER(ctype)
+    elif isinstance(data_type, ir.IntType):
+        if data_type.width == 8:
+            return ctypes.c_int8
+        elif data_type.width == 16:
+            return ctypes.c_int16
+        elif data_type.width == 32:
+            return ctypes.c_int32
+        elif data_type.width == 64:
+            return ctypes.c_int64
+        else:
+            raise ValueError("Int width %d is not supported" % data_type.width)
+    elif isinstance(data_type, ir.FloatType):
+        return ctypes.c_float
+    elif isinstance(data_type, ir.DoubleType):
+        return ctypes.c_double
+    elif isinstance(data_type, ir.VoidType):
+        return None  # Void type is not supported by ctypes
+    else:
+        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
+
+
+def to_llvm_type(data_type):
+    """
+    Transforms a given type into the corresponding llvmlite IR type (pointers are converted recursively)
+    :param data_type: Subclass of Type
+    :return: llvmlite type object
+    """
+    if isinstance(data_type, PointerType):
+        return to_llvm_type(data_type.baseType).as_pointer()
+    else:
+        return to_llvm_type.map[data_type.numpyDtype]
+
+to_llvm_type.map = {
+    np.dtype(np.int8): ir.IntType(8),
+    np.dtype(np.int16): ir.IntType(16),
+    np.dtype(np.int32): ir.IntType(32),
+    np.dtype(np.int64): ir.IntType(64),
+
+    np.dtype(np.uint8): ir.IntType(8),
+    np.dtype(np.uint16): ir.IntType(16),
+    np.dtype(np.uint32): ir.IntType(32),
+    np.dtype(np.uint64): ir.IntType(64),
+
+    np.dtype(np.float32): ir.FloatType(),
+    np.dtype(np.float64): ir.DoubleType(),
+}
+
 def peelOffType(dtype, typeToPeelOff):
     while type(dtype) is typeToPeelOff:
         dtype = dtype.baseType
@@ -210,7 +285,7 @@ def getTypeOfExpression(expr):
         return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults))
     elif isinstance(expr, sp.Indexed):
         typedSymbol = expr.base.label
-        return typedSymbol.dtype
+        return typedSymbol.dtype.baseType
     elif isinstance(expr, sp.boolalg.Boolean):
         # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
         result = createTypeFromString("bool")
@@ -222,31 +297,36 @@ def getTypeOfExpression(expr):
         types = tuple(getTypeOfExpression(a) for a in expr.args)
         return collateTypes(types)
 
-    raise NotImplementedError("Could not determine type for " + str(expr))
+    raise NotImplementedError("Could not determine type for", expr, type(expr))
 
 
 class Type(sp.Basic):
     def __new__(cls, *args, **kwargs):
         return sp.Basic.__new__(cls)
 
-    def __lt__(self, other):
+    def __lt__(self, other):  # deprecated
         # Needed for sorting the types inside an expression
         if isinstance(self, BasicType):
             if isinstance(other, BasicType):
-                return self.numpyDtype < other.numpyDtype  # TODO const
-            if isinstance(other, PointerType):
+                return self.numpyDtype > other.numpyDtype  # TODO const
+            elif isinstance(other, PointerType):
                 return False
-            if isinstance(other, StructType):
+            else:  # isinstance(other, StructType):
                 raise NotImplementedError("Struct type comparison is not yet implemented")
-        if isinstance(self, PointerType):
+        elif isinstance(self, PointerType):
             if isinstance(other, BasicType):
                 return True
-            if isinstance(other, PointerType):
-                return self.baseType < other.baseType  # TODO const, restrict
-            if isinstance(other, StructType):
+            elif isinstance(other, PointerType):
+                return self.baseType > other.baseType  # TODO const, restrict
+            else:  # isinstance(other, StructType):
                 raise NotImplementedError("Struct type comparison is not yet implemented")
-        if isinstance(self, StructType):
+        elif isinstance(self, StructType):
             raise NotImplementedError("Struct type comparison is not yet implemented")
+        else:
+            raise NotImplementedError
+
+    def _sympystr(self, *args, **kwargs):
+        return str(self)
 
 
 class BasicType(Type):
@@ -317,6 +397,9 @@ class BasicType(Type):
             result += " const"
         return result
 
+    def __repr__(self):
+        return str(self)
+
     def __eq__(self, other):
         if not isinstance(other, BasicType):
             return False
@@ -397,6 +480,9 @@ class PointerType(Type):
     def __str__(self):
         return "%s *%s%s" % (self.baseType, " RESTRICT " if self.restrict else "", " const " if self.const else "")
 
+    def __repr__(self):
+        return str(self)
+
     def __hash__(self):
         return hash(str(self))
 
@@ -444,6 +530,9 @@ class StructType(object):
             result += " const"
         return result
 
+    def __repr__(self):
+        return str(self)
+
     def __hash__(self):
         return hash((self.numpyDtype, self.const))
 
@@ -475,6 +564,7 @@ def get_type_from_sympy(node):
     elif isinstance(node, sp.Integer):
         return createType('int'), int(node)
     elif isinstance(node, sp.Rational):
-        raise NotImplementedError('Rationals are not supported yet')
+        # TODO is it always float?
+        return createType('double'), float(node.p/node.q)
     else:
         raise TypeError(node, ' is not a supported type (yet)!')
diff --git a/field.py b/field.py
index 98e388c49..3fb5b72e3 100644
--- a/field.py
+++ b/field.py
@@ -317,6 +317,10 @@ class Field(object):
         def offsets(self):
             return self._offsets
 
+        @offsets.setter
+        def offsets(self, value):
+            self._offsets = value
+
         @property
         def requiredGhostLayers(self):
             return int(np.max(np.abs(self._offsets)))
diff --git a/llvm/__init__.py b/llvm/__init__.py
index f34532f8b..16cd3d751 100644
--- a/llvm/__init__.py
+++ b/llvm/__init__.py
@@ -1,2 +1,3 @@
-from .kernelcreation import createKernel
-from .jit import compileLLVM
\ No newline at end of file
+from .kernelcreation import createKernel, createIndexedKernel
+from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function
+from .llvm import generateLLVM
diff --git a/llvm/jit.py b/llvm/jit.py
deleted file mode 100644
index 918c202f3..000000000
--- a/llvm/jit.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import llvmlite.ir as ir
-import llvmlite.binding as llvm
-from ..data_types import toCtypes, createType
-
-import ctypes as ct
-
-
-def compileLLVM(module):
-    jit = Jit()
-    jit.parse(module)
-    jit.optimize()
-    jit.compile()
-    return jit
-
-
-class Jit(object):
-    def __init__(self):
-        llvm.initialize()
-        llvm.initialize_all_targets()
-        llvm.initialize_native_target()
-        llvm.initialize_native_asmprinter()
-
-        self.module = None
-        self.llvmmod = None
-        self.target = llvm.Target.from_default_triple()
-        self.cpu = llvm.get_host_cpu_name()
-        self.cpu_features = llvm.get_host_cpu_features()
-        self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), opt=2)
-        self.ee = None
-        self.fptr = None
-
-    def parse(self, module):
-        self.module = module
-        llvmmod = llvm.parse_assembly(str(module))
-        llvmmod.verify()
-        self.llvmmod = llvmmod
-
-    def write_ll(self, file):
-        with open(file, 'w') as f:
-            f.write(str(self.llvmmod))
-
-    def optimize(self):
-        pmb = llvm.create_pass_manager_builder()
-        pmb.opt_level = 2
-        pmb.disable_unit_at_a_time = False
-        pmb.loop_vectorize = True
-        pmb.slp_vectorize = True
-        # TODO possible to pass for functions
-        pm = llvm.create_module_pass_manager()
-        pm.add_instruction_combining_pass()
-        pm.add_function_attrs_pass()
-        pm.add_constant_merge_pass()
-        pm.add_licm_pass()
-        pmb.populate(pm)
-        pm.run(self.llvmmod)
-
-    def compile(self, assembly_file=None, object_file=None):
-        ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)
-        ee.finalize_object()
-
-        if assembly_file is not None:
-            with open(assembly_file, 'w') as f:
-                f.write(self.target_machine.emit_assembly(self.llvmmod))
-        if object_file is not None:
-            with open(object_file, 'wb') as f:
-                f.write(self.target_machine.emit_object(self.llvmmod))
-
-        fptr = {}
-        for function in self.module.functions:
-            if not function.is_declaration:
-                return_type = None
-                if function.ftype.return_type != ir.VoidType():
-                    return_type = toCtypes(createType(str(function.ftype.return_type)))
-                args = [toCtypes(createType(str(arg))) for arg in function.ftype.args]
-                function_address = ee.get_function_address(function.name)
-                fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
-        self.ee = ee
-        self.fptr = fptr
-
-    def __call__(self, function, *args, **kwargs):
-        self.fptr[function](*args, **kwargs)
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index 403c9bb53..a22c39040 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -1,8 +1,9 @@
 import sympy as sp
+from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
 from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, typingFromSympyInspection, \
-    typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
-    desympy_ast, insert_casts
-from pystencils.data_types import TypedSymbol
+    typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, insertCasts#, \
+    #desympy_ast, insert_casts
+from pystencils.data_types import TypedSymbol, BasicType, StructType
 from pystencils.field import Field
 import pystencils.astnodes as ast
 
@@ -54,17 +55,85 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
         typedSplitGroups = [[typeSymbol(s) for s in splitGroup] for splitGroup in splitGroups]
         splitInnerLoop(code, typedSplitGroups)
 
-    basePointerInfo = [['spatialInner0'], ['spatialInner1']]
+    basePointerInfo = []
+    for i in range(len(loopOrder)):
+        basePointerInfo.append(['spatialInner%d' % i])
     basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
 
     resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
     moveConstantsBeforeLoop(code)
 
-    print('Ast:')
+    #print('Ast:')
+    #print(code)
+    #desympy_ast(code)
+    #print('Desympied ast:')
+    #print(code)
+    #insert_casts(code)
     print(code)
-    desympy_ast(code)
-    print('Desympied ast:')
+    code = insertCasts(code)
     print(code)
-    insert_casts(code)
-
     return code
+
+
+def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
+                        coordinateNames=('x', 'y', 'z')):
+    """
+    Similar to :func:`createKernel`, but here not all cells of a field are updated but only cells with
+    coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
+
+    The coordinates are stored in a separated indexField, which is a one dimensional array with struct data type.
+    This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
+    'coordinateNames' parameter. The struct can have also other fields that can be read and written in the kernel, for
+    example boundary parameters.
+
+    :param listOfEquations: list of update equations or AST nodes
+    :param indexFields: list of index fields, i.e. 1D fields with struct data type
+    :param typeForSymbol: see documentation of :func:`createKernel`
+    :param functionName: see documentation of :func:`createKernel`
+    :param coordinateNames: name of the coordinate fields in the struct data type
+    :return: abstract syntax tree
+    """
+    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
+    allFields = fieldsRead.union(fieldsWritten)
+
+    for indexField in indexFields:
+        indexField.isIndexField = True
+        assert indexField.spatialDimensions == 1, "Index fields have to be 1D"
+
+    nonIndexFields = [f for f in allFields if f not in indexFields]
+    spatialCoordinates = {f.spatialDimensions for f in nonIndexFields}
+    assert len(spatialCoordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
+    spatialCoordinates = list(spatialCoordinates)[0]
+
+    def getCoordinateSymbolAssignment(name):
+        for indexField in indexFields:
+            assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype"
+            dataType = indexField.dtype
+            if dataType.hasElement(name):
+                rhs = indexField[0](name)
+                lhs = TypedSymbol(name, BasicType(dataType.getElementType(name)))
+                return SympyAssignment(lhs, rhs)
+        raise ValueError("Index %s not found in any of the passed index fields" % (name,))
+
+    coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]]
+    coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments]
+    assignments = coordinateSymbolAssignments + assignments
+
+    # make 1D loop over index fields
+    loopBody = Block([])
+    loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0])
+
+    for assignment in assignments:
+        loopBody.append(assignment)
+
+    functionBody = Block([loopNode])
+    ast = KernelFunction(functionBody, allFields, functionName)
+
+    fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields}
+    resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping)
+    moveConstantsBeforeLoop(ast)
+
+    # desympy_ast/insert_casts are no longer imported; insertCasts returns the cast-annotated AST
+    ast = insertCasts(ast)
+
+    return ast
diff --git a/backends/llvm.py b/llvm/llvm.py
similarity index 81%
rename from backends/llvm.py
rename to llvm/llvm.py
index 34fed3765..14cbcc681 100644
--- a/backends/llvm.py
+++ b/llvm/llvm.py
@@ -6,16 +6,20 @@ from sympy import S
 # S is numbers?
 
 from pystencils.llvm.control_flow import Loop
-from ..data_types import createType
-from ..astnodes import Indexed
+from pystencils.data_types import createType, to_llvm_type, getTypeOfExpression
+from sympy import Indexed  # TODO used astnodes, this should not work!
 
 
-def generateLLVM(ast_node, module=ir.Module(), builder=ir.IRBuilder()):
+def generateLLVM(ast_node, module=None, builder=None):
     """
     Prints the ast as llvm code
     """
+    if module is None:
+        module = ir.Module()
+    if builder is None:
+        builder = ir.IRBuilder()
     printer = LLVMPrinter(module, builder)
-    return printer._print(ast_node)
+    return printer._print(ast_node) #TODO use doprint() instead???
 
 
 class LLVMPrinter(Printer):
@@ -37,19 +41,22 @@ class LLVMPrinter(Printer):
     def _add_tmp_var(self, name, value):
         self.tmp_var[name] = value
 
+    def _remove_tmp_var(self, name):
+        del self.tmp_var[name]
+
     def _print_Number(self, n):
-        if n.dtype == createType("int"):
+        if getTypeOfExpression(n) == createType("int"):
             return ir.Constant(self.integer, int(n))
-        elif n.dtype == createType("double"):
+        elif getTypeOfExpression(n) == createType("double"):
             return ir.Constant(self.fp_type, float(n))
         else:
             raise NotImplementedError("Numbers can only have int and double", n)
 
     def _print_Float(self, expr):
-        return ir.Constant(self.fp_type, expr.p)
+        return ir.Constant(self.fp_type, float(expr))
 
     def _print_Integer(self, expr):
-        return ir.Constant(self.integer, expr.p)
+        return ir.Constant(self.integer, int(expr))
 
     def _print_int(self, i):
         return ir.Constant(self.integer, i)
@@ -64,6 +71,7 @@ class LLVMPrinter(Printer):
         return val
 
     def _print_Pow(self, expr):
+        #print(expr)
         base0 = self._print(expr.base)
         if expr.exp == S.NegativeOne:
             return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
@@ -88,9 +96,9 @@ class LLVMPrinter(Printer):
     def _print_Mul(self, expr):
         nodes = [self._print(a) for a in expr.args]
         e = nodes[0]
-        if expr.dtype == createType('double'):
+        if getTypeOfExpression(expr) == createType('double'):
             mul = self.builder.fmul
-        else: # int TODO others?
+        else:  # int TODO unsigned/signed
             mul = self.builder.mul
         for node in nodes[1:]:
             e = mul(e, node)
@@ -99,24 +107,20 @@ class LLVMPrinter(Printer):
     def _print_Add(self, expr):
         nodes = [self._print(a) for a in expr.args]
         e = nodes[0]
-        if expr.dtype == createType('double'):
+        if getTypeOfExpression(expr) == createType('double'):
             add = self.builder.fadd
-        else: # int TODO others?
+        else:  # int TODO unsigned/signed
             add = self.builder.add
         for node in nodes[1:]:
             e = add(e, node)
         return e
 
     def _print_KernelFunction(self, function):
+        # KernelFunction does not possess a return type
         return_type = self.void
-        # TODO argument in their own call? -> nope
         parameter_type = []
         for parameter in function.parameters:
-            # TODO what about ptr shape and stride argument?
-            if parameter.isFieldArgument:
-                parameter_type.append(self.fp_pointer)
-            else:
-                parameter_type.append(self.fp_type)
+            parameter_type.append(to_llvm_type(parameter.dtype))
         func_type = ir.FunctionType(return_type, tuple(parameter_type))
         name = function.functionName
         fn = ir.Function(self.module, func_type, name)
@@ -130,7 +134,7 @@ class LLVMPrinter(Printer):
         # func.attributes.add("inlinehint")
         # func.attributes.add("argmemonly")
         block = fn.append_basic_block(name="entry")
-        self.builder = ir.IRBuilder(block)
+        self.builder = ir.IRBuilder(block) #TODO use goto_block instead
         self._print(function.body)
         self.builder.ret_void()
         self.fn = fn
@@ -144,8 +148,8 @@ class LLVMPrinter(Printer):
         with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
                   loop.loopCounterName, loop.loopCounterSymbol.name) as i:
             self._add_tmp_var(loop.loopCounterSymbol, i)
-            # TODO remove tmp var
             self._print(loop.body)
+            self._remove_tmp_var(loop.loopCounterSymbol)
 
     def _print_SympyAssignment(self, assignment):
         expr = self._print(assignment.rhs)
@@ -158,10 +162,10 @@ class LLVMPrinter(Printer):
         self.func_arg_map[assignment.lhs.name] = expr
         return expr
 
-    def _print_Conversion(self, conversion):
+    def _print_castFunc(self, conversion):
         node = self._print(conversion.args[0])
-        to_dtype = conversion.dtype
-        from_dtype = conversion.args[0].dtype
+        to_dtype = getTypeOfExpression(conversion)
+        from_dtype = getTypeOfExpression(conversion.args[0])
         # (From, to)
         decision = {
             (createType("int"), createType("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
@@ -173,8 +177,9 @@ class LLVMPrinter(Printer):
             (createType("double * restrict const"), createType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
             (createType("int"), createType("double * restrict const")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
             }
-        # TODO float, const, restrict
+        # TODO float, TEST: const, restrict
         # TODO bitcast, addrspacecast
+        # TODO unsigned/signed fills
         # print([x for x in decision.keys()])
         # print("Types:")
         # print([(type(x), type(y)) for (x, y) in decision.keys()])
@@ -182,21 +187,21 @@ class LLVMPrinter(Printer):
         # print((from_dtype, to_dtype))
         return decision[(from_dtype, to_dtype)]()
 
+    def _print_pointerArithmeticFunc(self, pointer):
+        """Emit ``builder.gep`` on the printed ``pointer.args[0]``, indexed by the printed ``pointer.args[1]``."""
+        base, offset = [self._print(operand) for operand in (pointer.args[0], pointer.args[1])]
+        return self.builder.gep(base, [offset])
+
     def _print_Indexed(self, indexed):
         ptr = self._print(indexed.base.label)
         index = self._print(indexed.args[1])
         gep = self.builder.gep(ptr, [index])
         return self.builder.load(gep, name=indexed.base.label.name)
 
-    def _print_PointerArithmetic(self, pointer):
-        ptr = self._print(pointer.pointer)
-        index = self._print(pointer.offset)
-        return self.builder.gep(ptr, [index])
-
     # Should have a list of math library functions to validate this.
-    # TODO function calls
+    # TODO function calls to libs
     def _print_Function(self, expr):
-        name = expr.func.__name__
+        name = expr.name
         e0 = self._print(expr.args[0])
         fn = self.ext_fn.get(name)
         if not fn:
diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py
new file mode 100644
index 000000000..d767c0ee5
--- /dev/null
+++ b/llvm/llvmjit.py
@@ -0,0 +1,182 @@
+import llvmlite.ir as ir
+import llvmlite.binding as llvm
+import numpy as np
+import ctypes as ct
+import subprocess
+import shutil
+
+from ..data_types import toCtypes, createType, ctypes_from_llvm
+from .llvm import generateLLVM
+from ..cpu.cpujit import buildCTypeArgumentList
+
+
+def generate_and_jit(ast):
+    """Generate LLVM IR for ``ast`` and return the :class:`Jit` produced by compileLLVM."""
+    generated = generateLLVM(ast)
+    # generateLLVM may yield a bare ir.Module or an object carrying one as ``.module``.
+    module = generated if isinstance(generated, ir.Module) else generated.module
+    return compileLLVM(module)
+
+
+def make_python_function(ast, argumentDict=None, func=None):
+    """Return a zero-argument callable running the kernel, or a keyword wrapper if arguments are missing."""
+    argumentDict = {} if argumentDict is None else argumentDict  # avoid a shared mutable default
+    try:
+        args = buildCTypeArgumentList(ast.parameters, argumentDict)
+    except KeyError:  # not all parameters specified yet
+        return make_python_function_incomplete(ast, argumentDict, func)
+    if func is None:
+        func = generate_and_jit(ast).get_function_ptr(ast.functionName)
+    return lambda: func(*args)
+
+
+def make_python_function_incomplete(ast, argumentDict, func=None):
+    """Return a wrapper that accepts the still-missing kernel arguments as keywords.
+
+    Call-time keywords override ``argumentDict``; ctypes argument lists are cached per (name, id(value)).
+    """
+    if func is None:
+        func = generate_and_jit(ast).get_function_ptr(ast.functionName)
+    parameters = ast.parameters
+    cache = {}
+
+    def wrapper(**kwargs):
+        # Key on the tuple itself: keying on hash(tuple) let colliding calls share arguments.
+        key = tuple((k, id(v)) for k, v in kwargs.items())
+        args = cache.get(key)
+        if args is None:  # the kernel call stays outside any try so its errors are not swallowed
+            fullArguments = argumentDict.copy()
+            fullArguments.update(kwargs)
+            args = cache[key] = buildCTypeArgumentList(parameters, fullArguments)
+        func(*args)
+
+    return wrapper
+
+
+def compileLLVM(module):  # module: presumably an llvmlite ir.Module (parse() uses its str() as IR text)
+    jit = Jit()
+    jit.parse(module)  # verifies the IR text and installs it in the execution engine
+    jit.optimize()  # NOTE(review): runs after parse() already finalized the module in the engine
+    jit.compile()  # rebuilds jit.fptr, the name -> ctypes function table
+    return jit
+
+
+class Jit(object):
+    def __init__(self):  # initialize LLVM and create an MCJIT engine around an empty module
+        llvm.initialize()
+        llvm.initialize_all_targets()
+        llvm.initialize_native_target()
+        llvm.initialize_native_asmprinter()
+
+        self.module = None  # source module handed to parse(); compile() iterates its functions
+        self._llvmmod = llvm.parse_assembly("")  # empty placeholder so the engine can be created
+        self.target = llvm.Target.from_default_triple()
+        self.cpu = llvm.get_host_cpu_name()
+        self.cpu_features = llvm.get_host_cpu_features()
+        self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), opt=2)
+        llvm.check_jit_execution()
+        self.ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)  # read via the property getter
+        self.ee.finalize_object()
+        self.fptr = None  # dict: function name -> ctypes function, filled by compile()
+
+    @property
+    def llvmmod(self):  # the llvmlite binding module currently installed in the engine
+        return self._llvmmod
+
+    @llvmmod.setter
+    def llvmmod(self, mod):  # swap the engine's module for `mod` and rebuild the function table
+        self.ee.remove_module(self.llvmmod)
+        self.ee.add_module(mod)
+        self.ee.finalize_object()
+        self.compile()  # needs self.module to be set; it is None until parse() has run
+        self._llvmmod = mod
+
+    def parse(self, module):
+        """Parse the IR text of ``module``, verify it and install it via the ``llvmmod`` setter."""
+        self.module = module
+        parsed = llvm.parse_assembly(str(module))
+        parsed.verify()
+        parsed.triple, parsed.name = self.target.triple, 'module'
+        self.llvmmod = parsed
+
+    def write_ll(self, file):  # write the installed module's textual form (str(module)) to path `file`
+        with open(file, 'w') as f:
+            f.write(str(self.llvmmod))
+
+    def write_assembly(self, file):  # write what target_machine.emit_assembly() yields, in text mode
+        with open(file, 'w') as f:
+            f.write(self.target_machine.emit_assembly(self.llvmmod))
+
+    def write_object_file(self, file):  # write what target_machine.emit_object() yields, in binary mode
+        with open(file, 'wb') as f:
+            f.write(self.target_machine.emit_object(self.llvmmod))
+
+    def optimize(self):  # run a module-level pass pipeline over self.llvmmod
+        pmb = llvm.create_pass_manager_builder()
+        pmb.opt_level = 2
+        pmb.disable_unit_at_a_time = False
+        pmb.loop_vectorize = True
+        pmb.slp_vectorize = True
+        # TODO possible to pass for functions
+        pm = llvm.create_module_pass_manager()
+        pm.add_instruction_combining_pass()  # these four passes are added ahead of the builder's passes
+        pm.add_function_attrs_pass()
+        pm.add_constant_merge_pass()
+        pm.add_licm_pass()
+        pmb.populate(pm)  # adds the passes configured on the builder above
+        pm.run(self.llvmmod)  # NOTE(review): module was already finalized in self.ee — confirm the JIT code picks this up
+
+    def optimize_polly(self, opt):
+        """Run the module through the external ``opt`` binary (Polly passes, then -O3) and install the result."""
+        if shutil.which(opt) is None:
+            print('Path to the executable is wrong')
+            return
+        stages = (['-polly-canonicalize'],
+                  ['-polly-codegen', '-polly-vectorizer=polly', '-polly-parallel', '-polly-process-unprofitable', '-f'],
+                  ['-O3', '-f'])
+        # Feed each stage explicitly. Chaining the Popen pipes and then calling communicate() on the
+        # first process made the parent read that process's stdout itself, racing the next stage for it.
+        bitcode = self.llvmmod.as_bitcode()
+        for flags in stages:
+            stage = subprocess.Popen([opt] + flags, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+            bitcode, _ = stage.communicate(input=bitcode)
+        llvmmod = llvm.parse_bitcode(bitcode)
+        llvmmod.verify()
+        self.llvmmod = llvmmod
+
+    def compile(self):
+        """Rebuild ``self.fptr`` with a ctypes callable for each function defined in ``self.module``."""
+        fptr = {}
+        for function in (f for f in self.module.functions if not f.is_declaration):
+            ftype = function.ftype
+            returns_value = ftype.return_type != ir.VoidType()
+            return_type = toCtypes(createType(str(ftype.return_type))) if returns_value else None
+            args = [toCtypes(createType(str(arg))) for arg in ftype.args]
+            function_address = self.ee.get_function_address(function.name)
+            fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
+        self.fptr = fptr
+
+    def __call__(self, function, *args, **kwargs):
+        """Call the compiled function named ``function`` with ``args`` and return its result.
+
+        numpy arrays are passed as pointers of the matching parameter type; ``kwargs`` are ignored.
+        """
+        target_function = next(f for f in self.module.functions if f.name == function)
+        arg_types = [ctypes_from_llvm(arg.type) for arg in target_function.args]
+        transformed_args = [arg.ctypes.data_as(arg_types[i]) if isinstance(arg, np.ndarray) else arg
+                            for i, arg in enumerate(args)]
+        # Hand the result back to the caller: compile() sets up return types for non-void
+        # functions, but the value used to be discarded here.
+        return self.fptr[function](*transformed_args)
+
+    def print_functions(self):  # print return type, name and arguments of every function in self.module
+        for f in self.module.functions:
+            print(f.ftype.return_type, f.name, f.args)
+
+    def get_function_ptr(self, name):  # raises KeyError if compile() produced no function `name`
+        fptr = self.fptr[name]
+        fptr.jit = self  # back-reference from the callable to this Jit — presumably keeps the engine alive
+        return fptr
+
+
+
diff --git a/transformations/__init__.py b/transformations/__init__.py
new file mode 100644
index 000000000..a8ba7d85a
--- /dev/null
+++ b/transformations/__init__.py
@@ -0,0 +1,2 @@
+from .transformations import *
+from .stage2 import *
diff --git a/transformations/stage2.py b/transformations/stage2.py
new file mode 100644
index 000000000..d1cfd2d15
--- /dev/null
+++ b/transformations/stage2.py
@@ -0,0 +1,159 @@
+from operator import attrgetter
+
+import sympy as sp
+
+from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, getTypeOfExpression, collateTypes, castFunc, pointerArithmeticFunc
+import pystencils.astnodes as ast
+
+
+def insertCasts(node):  # TODO test casts!!!, edit testcase
+    """
+    Checks the types and inserts casts and pointer arithmetic where necessary.
+    :param node: the head node of the ast
+    :return: the ast with casts; sympy expressions are rebuilt, Block and LoopOverCoordinate are updated in place
+    """
+    def cast(zippedArgsTypes, target):
+        """
+        Adds casts to the arguments if their type differs from the target type
+        :param zippedArgsTypes: a zipped list of args and types
+        :param target: The target data type
+        :return: args with possible casts
+        """
+        casted_args = []
+        for arg, dataType in zippedArgsTypes:
+            if dataType.numpyDtype != target.numpyDtype:  # TODO ignoring const
+                casted_args.append(castFunc(arg, target))
+            else:
+                casted_args.append(arg)
+        return casted_args
+
+    def pointerArithmetic(args):
+        """
+        Creates a valid pointer arithmetic function
+        :param args: (argument, type) pairs of the add expression; at most one may be a pointer
+        :return: pointerArithmeticFunc
+        """
+        pointer = None
+        newArgs = []
+        for arg, dataType in args:
+            if dataType.func == PointerType:
+                assert pointer is None
+                pointer = arg
+        for arg, dataType in args:
+            if arg != pointer:
+                assert dataType.is_int() or dataType.is_uint()
+                newArgs.append(arg)
+        newArgs = sp.Add(*newArgs) if newArgs else newArgs
+        return pointerArithmeticFunc(pointer, newArgs)
+
+    if isinstance(node, sp.AtomicExpr):
+        return node
+    # Process the children first; their (possibly cast) results become this node's arguments.
+    # Atomic expressions (handled above) end the recursion.
+    args = [insertCasts(arg) for arg in node.args]
+    # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow
+    if node.func in (sp.Add, sp.Mul):
+        types = [getTypeOfExpression(arg) for arg in args]
+        assert len(types) > 0
+        target = collateTypes(types)
+        # Pair each argument with its type; no debug printing in library code.
+        zipped = list(zip(args, types))
+        if target.func == PointerType:
+            assert node.func == sp.Add
+            return pointerArithmetic(zipped)
+        else:
+            return node.func(*cast(zipped, target))
+    elif node.func == ast.SympyAssignment:
+        # TODO casting of rhs/lhs
+        return node.func(*args)
+    elif node.func == ast.ResolvedFieldAccess:
+        # Field accesses are returned unchanged for now.
+        # TODO Everything
+        return node
+    elif node.func == ast.Block:
+        for oldArg, newArg in zip(node.args, args):
+            node.replace(oldArg, newArg)
+        return node
+    elif node.func == ast.LoopOverCoordinate:
+        for oldArg, newArg in zip(node.args, args):
+            node.replace(oldArg, newArg)
+        return node
+
+    # Any other node type: rebuild it from the processed arguments.
+    return node.func(*args)
+
+
+def insert_casts(node):
+    """
+    Inserts casts and dtype where needed; children are processed first and ``node`` is modified in place.
+    :param node: ast which should be traversed
+    :return: node
+    """
+    def conversion(args):  # args: node arguments sorted by dtype; args[0] supplies the target dtype
+        target = args[0]
+        if isinstance(target.dtype, PointerType):
+            # Pointer arithmetic: every remaining argument must be an integer offset
+            for arg in args[1:]:
+                # Check validity
+                if not arg.dtype.is_int() and not arg.dtype.is_uint():
+                    raise ValueError("Impossible pointer arithmetic", target, arg)
+            pointer = ast.PointerArithmetic(ast.Add(args[1:]), target)  # NOTE(review): confirm these ast classes still exist
+            return [pointer]
+
+        else:
+            for i in range(len(args)):
+                if args[i].dtype.numpyDtype != target.dtype.numpyDtype:  # TODO ignoring const -> valid behavior?
+                    args[i] = ast.Conversion(args[i], createType(target.dtype), node)
+            return args
+
+    for arg in node.args:
+        insert_casts(arg)
+    if isinstance(node, ast.Indexed):
+        # TODO need to do something here?
+        pass
+    elif isinstance(node, ast.Expr):
+        args = sorted((arg for arg in node.args), key=attrgetter('dtype'))
+        target = args[0]  # first argument after sorting by dtype
+        node.args = conversion(args)
+        node.dtype = target.dtype
+    elif isinstance(node, ast.SympyAssignment):
+        if node.lhs.dtype != node.rhs.dtype:  # cast the right-hand side to the left-hand side's type
+            node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
+    elif isinstance(node, ast.LoopOverCoordinate):
+        pass
+    return node
+
+
+#def desympy_ast(node):
+#    """
+#    Remove Sympy Expressions, which have more than one argument.
+#    This is necessary for further changes in the tree.
+#    :param node: ast which should be traversed. Only node's children will be modified.
+#    :return: (modified) node
+#    """
+#    if node.args is None:
+#        return node
+#    for i in range(len(node.args)):
+#        arg = node.args[i]
+#        if isinstance(arg, sp.Add):
+#            node.replace(arg, ast.Add(arg.args, node))
+#        elif isinstance(arg, sp.Number):
+#            node.replace(arg, ast.Number(arg, node))
+#        elif isinstance(arg, sp.Mul):
+#            node.replace(arg, ast.Mul(arg.args, node))
+#        elif isinstance(arg, sp.Pow):
+#            node.replace(arg, ast.Pow(arg, node))
+#        elif isinstance(arg, sp.tensor.Indexed) or isinstance(arg, sp.tensor.indexed.Indexed):
+#            node.replace(arg, ast.Indexed(arg.args, arg.base, node))
+#        elif isinstance(arg,  sp.tensor.IndexedBase):
+#            node.replace(arg, arg.label)
+#        elif isinstance(arg, sp.Function):
+#            node.replace(arg, ast.Function(arg.func, arg.args, node))
+#        #elif isinstance(arg, sp.containers.Tuple):
+#        #
+#        else:
+#            #print('Not transforming:', type(arg), arg)
+#            pass
+#    for arg in node.args:
+#        desympy_ast(arg)
+#    return node
diff --git a/transformations.py b/transformations/transformations.py
similarity index 99%
rename from transformations.py
rename to transformations/transformations.py
index 5c1fe5698..7cc41daf1 100644
--- a/transformations.py
+++ b/transformations/transformations.py
@@ -195,7 +195,6 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
             if i in specifiedCoordinates:
                 raise ValueError("Coordinate %d specified two times" % (i,))
             specifiedCoordinates.add(i)
-
         for element in specGroup:
             if type(element) is int:
                 addNewElement(element)
-- 
GitLab