diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 5be04bb0c6d0ab419bf657f73cb3bcc64d014252..75db5a790bea85c00c72d4a49916a23d867b55ee 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -363,6 +363,8 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
     argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
     ctArguments = []
     arrayShapes = set()
+    indexArrShapes = set()
+
     for arg in parameterSpecification:
         if arg.isFieldArgument:
             try:
@@ -388,8 +390,11 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
                         raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                          (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))
 
-                if not symbolicField.isIndexField:
+                if symbolicField.isIndexField:
+                    indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
+                else:
                     arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
+
             elif arg.isFieldShapeArgument:
                 dataType = toCtypes(getBaseType(arg.dtype))
                 ctArguments.append(fieldArr.ctypes.shape_as(dataType))
@@ -412,6 +417,9 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
 
     if len(arrayShapes) > 1:
         raise ValueError("All passed arrays have to have the same size " + str(arrayShapes))
+    if len(indexArrShapes) > 1:
+        raise ValueError("All passed index arrays have to have the same size " + str(arrayShapes))
+
     return ctArguments
 
 
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 26d710d1b0c605244a98cb7f4be7c5b310603b60..d05c3143bdee8df0cb13229b2e6a2ec7524ff721 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -60,7 +60,8 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     return code
 
 
-def createIndexedKernel(listOfEquations, indexFields, typeForSymbol=None, coordinateNames=('x', 'y', 'z')):
+def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
+                        coordinateNames=('x', 'y', 'z')):
     """
     Similar to :func:`createKernel`, but here not all cells of a field are updated but only cells with
     coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
@@ -73,6 +74,7 @@ def createIndexedKernel(listOfEquations, indexFields, typeForSymbol=None, coordi
     :param listOfEquations: list of update equations or AST nodes
     :param indexFields: list of index fields, i.e. 1D fields with struct data type
     :param typeForSymbol: see documentation of :func:`createKernel`
+    :param functionName: see documentation of :func:`createKernel`
     :param coordinateNames: name of the coordinate fields in the struct data type
     :return: abstract syntax tree
     """
@@ -110,7 +112,7 @@ def createIndexedKernel(listOfEquations, indexFields, typeForSymbol=None, coordi
         loopBody.append(assignment)
 
     functionBody = Block([loopNode])
-    ast = KernelFunction(functionBody, allFields.union(indexFields))
+    ast = KernelFunction(functionBody, allFields, functionName)
 
     fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields}
     resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping)
diff --git a/gpucuda/__init__.py b/gpucuda/__init__.py
index d211c9eb4dcdecc95a6ba72eb543dff673b34577..0a8845b9fd573eca51453b61e0d6c5f3196c2757 100644
--- a/gpucuda/__init__.py
+++ b/gpucuda/__init__.py
@@ -1,2 +1,2 @@
-from pystencils.gpucuda.kernelcreation import createCUDAKernel
+from pystencils.gpucuda.kernelcreation import createCUDAKernel, createdIndexedCUDAKernel
 from pystencils.gpucuda.cudajit import makePythonFunction
diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py
index 57fb3da588b085cae372b025d5ae243accbdf7f3..c754afcd6cf795cb88dd0f5f386a08a378371bb1 100644
--- a/gpucuda/cudajit.py
+++ b/gpucuda/cudajit.py
@@ -2,39 +2,44 @@ import numpy as np
 import pycuda.driver as cuda
 import pycuda.autoinit
 from pycuda.compiler import SourceModule
-from pycuda.gpuarray import GPUArray
 from pystencils.backends.cbackend import generateC
 from pystencils.transformations import symbolNameToVariableName
+from pystencils.types import StructType
 
 
-def numpyTypeFromString(typename, includePointers=True):
-    import ctypes as ct
+def makePythonFunction(kernelFunctionNode, argumentDict={}):
+    """
+    Creates a kernel function from an abstract syntax tree which
+    was created e.g. by :func:`pystencils.gpucuda.createCUDAKernel`
+    or :func:`pystencils.gpucuda.createdIndexedCUDAKernel`
 
-    typename = typename.replace("*", " * ")
-    typeComponents = typename.split()
+    :param kernelFunctionNode: the abstract syntax tree
+    :param argumentDict: parameters passed here are already fixed. Remaining parameters have to be passed to the
+                        returned kernel functor.
+    :return: kernel functor
+    """
+    code = "#include <cstdint>\n"
+    code += "#define FUNC_PREFIX __global__\n"
+    code += "#define RESTRICT __restrict__\n\n"
+    code += str(generateC(kernelFunctionNode))
 
-    basicTypeMap = {
-        'double': np.float64,
-        'float': np.float32,
-        'int': np.int32,
-        'long': np.int64,
-    }
+    mod = SourceModule(code, options=["-w", "-std=c++11"])
+    func = mod.get_function(kernelFunctionNode.functionName)
+
+    def wrapper(**kwargs):
+        from copy import copy
+        fullArguments = copy(argumentDict)
+        fullArguments.update(kwargs)
+        shape = _checkArguments(kernelFunctionNode.parameters, fullArguments)
 
-    resultType = None
-    for typeComponent in typeComponents:
-        typeComponent = typeComponent.strip()
-        if typeComponent == "const" or typeComponent == "restrict" or typeComponent == "volatile":
-            continue
-        if typeComponent in basicTypeMap:
-            resultType = basicTypeMap[typeComponent]
-        elif typeComponent == "*" and includePointers:
-            assert resultType is not None
-            resultType = ct.POINTER(resultType)
+        dictWithBlockAndThreadNumbers = kernelFunctionNode.getCallParameters(shape)
 
-    return resultType
+        args = _buildNumpyArgumentList(kernelFunctionNode, fullArguments)
+        func(*args, **dictWithBlockAndThreadNumbers)
+    return wrapper
 
 
-def buildNumpyArgumentList(kernelFunctionNode, argumentDict):
+def _buildNumpyArgumentList(kernelFunctionNode, argumentDict):
     argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
     result = []
     for arg in kernelFunctionNode.parameters:
@@ -52,38 +57,57 @@ def buildNumpyArgumentList(kernelFunctionNode, argumentDict):
                 assert False
         else:
             param = argumentDict[arg.name]
-            expectedType = numpyTypeFromString(arg.dtype)
+            expectedType = arg.dtype.numpyDtype
             result.append(expectedType(param))
     return result
 
 
-def makePythonFunction(kernelFunctionNode, argumentDict={}):
-    code = "#define FUNC_PREFIX __global__\n"
-    code += "#define RESTRICT __restrict__\n\n"
-    code += str(generateC(kernelFunctionNode))
+def _checkArguments(parameterSpecification, argumentDict):
+    """
+    Checks if parameters passed to kernel match the description in the AST function node.
+    If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads
+    """
+    argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
+    arrayShapes = set()
+    indexArrShapes = set()
+    for arg in parameterSpecification:
+        if arg.isFieldArgument:
+            try:
+                fieldArr = argumentDict[arg.fieldName]
+            except KeyError:
+                raise KeyError("Missing field parameter for kernel call " + arg.fieldName)
 
-    mod = SourceModule(code, options=["-w"])
-    func = mod.get_function(kernelFunctionNode.functionName)
+            symbolicField = arg.field
+            if arg.isFieldPtrArgument:
+                if symbolicField.hasFixedShape:
+                    symbolicFieldShape = tuple(int(i) for i in symbolicField.shape)
+                    if isinstance(symbolicField.dtype, StructType):
+                        symbolicFieldShape = symbolicFieldShape[:-1]
+                    if symbolicFieldShape != fieldArr.shape:
+                        raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
+                                         (arg.fieldName, str(fieldArr.shape), str(symbolicField.shape)))
+                if symbolicField.hasFixedShape:
+                    symbolicFieldStrides = tuple(int(i) * fieldArr.dtype.itemsize for i in symbolicField.strides)
+                    if isinstance(symbolicField.dtype, StructType):
+                        symbolicFieldStrides = symbolicFieldStrides[:-1]
+                    if symbolicFieldStrides != fieldArr.strides:
+                        raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
+                                         (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))
 
-    def wrapper(**kwargs):
-        from copy import copy
-        fullArguments = copy(argumentDict)
-        fullArguments.update(kwargs)
+                if symbolicField.isIndexField:
+                    indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
+                else:
+                    arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
+
+    if len(arrayShapes) > 1:
+        raise ValueError("All passed arrays have to have the same size " + str(arrayShapes))
+    if len(indexArrShapes) > 1:
+        raise ValueError("All passed index arrays have to have the same size " + str(arrayShapes))
+
+    if len(indexArrShapes) > 0:
+        return list(indexArrShapes)[0]
+    else:
+        return list(arrayShapes)[0]
 
-        shapes = set()
-        strides = set()
-        for argValue in fullArguments.values():
-            if isinstance(argValue, GPUArray):
-                shapes.add(argValue.shape)
-                strides.add(argValue.strides)
-        if len(strides) == 0:
-            raise ValueError("No GPU arrays passed as argument")
-        assert len(strides) < 2, "All passed arrays have to have the same strides"
-        assert len(shapes) < 2, "All passed arrays have to have the same size"
-        shape = list(shapes)[0]
-        dictWithBlockAndThreadNumbers = kernelFunctionNode.getCallParameters(shape)
 
-        args = buildNumpyArgumentList(kernelFunctionNode, fullArguments)
-        func(*args, **dictWithBlockAndThreadNumbers)
-    return wrapper
 
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 09d1efdf1c2a844840ace129bb4f55435bc4a8ec..512334898583749c2ab7ddb861c186e26c82c8f6 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -2,8 +2,9 @@ import sympy as sp
 
 from pystencils.transformations import resolveFieldAccesses, typeAllEquations, \
     parseBasePointerInfo, typingFromSympyInspection
-from pystencils.astnodes import Block, KernelFunction
+from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
 from pystencils import Field
+from pystencils.types import TypedSymbol, BasicType, StructType
 
 BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
 THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z"))
@@ -31,19 +32,14 @@ def getLinewiseCoordinates(field, ghostLayers):
 
 
 def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None):
-    if not typeForSymbol or typeForSymbol == 'double':
-        typeForSymbol = typingFromSympyInspection(listOfEquations, "double")
-    elif typeForSymbol == 'float':
-        typeForSymbol = typingFromSympyInspection(listOfEquations, "float")
-
     fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
     allFields = fieldsRead.union(fieldsWritten)
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
-    code = KernelFunction(Block(assignments), allFields, functionName)
-    code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
+    ast = KernelFunction(Block(assignments), allFields, functionName)
+    ast.globalVariables.update(BLOCK_IDX + THREAD_IDX)
 
-    fieldAccesses = code.atoms(Field.Access)
+    fieldAccesses = ast.atoms(Field.Access)
     requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
 
     coordMapping, getCallParameters = getLinewiseCoordinates(list(fieldsRead)[0], requiredGhostLayers)
@@ -51,10 +47,63 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None)
     basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}
 
     coordMapping = {f.name: coordMapping for f in allFields}
-    resolveFieldAccesses(code, readOnlyFields, fieldToFixedCoordinates=coordMapping,
+    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
                          fieldToBasePointerInfo=basePointerInfos)
     # add the function which determines #blocks and #threads as additional member to KernelFunction node
     # this is used by the jit
-    code.getCallParameters = getCallParameters
-    return code
+    ast.getCallParameters = getCallParameters
+    return ast
+
+
+def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
+                             coordinateNames=('x', 'y', 'z')):
+    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
+    allFields = fieldsRead.union(fieldsWritten)
+    readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
+    for indexField in indexFields:
+        indexField.isIndexField = True
+        assert indexField.spatialDimensions == 1, "Index fields have to be 1D"
+
+    nonIndexFields = [f for f in allFields if f not in indexFields]
+    spatialCoordinates = {f.spatialDimensions for f in nonIndexFields}
+    assert len(spatialCoordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
+    spatialCoordinates = list(spatialCoordinates)[0]
+
+    def getCoordinateSymbolAssignment(name):
+        for indexField in indexFields:
+            assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype"
+            dataType = indexField.dtype
+            if dataType.hasElement(name):
+                rhs = indexField[0](name)
+                lhs = TypedSymbol(name, BasicType(dataType.getElementType(name)))
+                return SympyAssignment(lhs, rhs)
+        raise ValueError("Index %s not found in any of the passed index fields" % (name,))
+
+    coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]]
+    coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments]
+    assignments = coordinateSymbolAssignments + assignments
+
+    # make 1D loop over index fields
+    loopBody = Block([])
+    loopNode = LoopOverCoordinate(loopBody, coordinateToLoopOver=0, start=0, stop=indexFields[0].shape[0])
+
+    for assignment in assignments:
+        loopBody.append(assignment)
+
+    functionBody = Block([loopNode])
+    ast = KernelFunction(functionBody, allFields, functionName)
+    ast.globalVariables.update(BLOCK_IDX + THREAD_IDX)
+
+    coordMapping, getCallParameters = getLinewiseCoordinates(list(fieldsRead)[0], ghostLayers=0)
+    basePointerInfo = [['spatialInner0']]
+    basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}
+
+    coordMapping = {f.name: coordMapping for f in indexFields}
+    coordMapping.update({f.name: coordinateTypedSymbols for f in nonIndexFields})
+    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
+                         fieldToBasePointerInfo=basePointerInfos)
+    # add the function which determines #blocks and #threads as additional member to KernelFunction node
+    # this is used by the jit
+    ast.getCallParameters = getCallParameters
+    return ast
\ No newline at end of file