diff --git a/pystencils/llvm/llvm.py b/pystencils/llvm/llvm.py
index bfccbdc2e01817ff54aa1121ad91fbfb3991bcd9..1d5223e9509f274e002d755fa19a1a56f9ec016e 100644
--- a/pystencils/llvm/llvm.py
+++ b/pystencils/llvm/llvm.py
@@ -13,6 +13,24 @@ from pystencils.data_types import (
 from pystencils.llvm.control_flow import Loop
 
 
+# Adapted from Numba: mark an LLVM function as a CUDA kernel through NVVM metadata
+def set_cuda_kernel(lfunc):
+    from llvmlite.llvmpy.core import MetaData, MetaDataString, Constant, Type
+
+    m = lfunc.module
+
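+    # an nvvm.annotations entry of (function, "kernel", 1) tells the NVPTX backend
+    # that lfunc is a kernel entry point rather than a plain device function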
+    ops = lfunc, MetaDataString.get(m, "kernel"), Constant.int(Type.int(), 1)
+    md = MetaData.get(m, ops)
+
+    nmd = m.get_or_insert_named_metadata('nvvm.annotations')
+    nmd.add(md)
+
+    # declare the NVVM IR version as module-level metadata (values as used by Numba)
+    i32 = ir.IntType(32)
+    md_ver = m.add_metadata([i32(1), i32(2), i32(2), i32(0)])
+    m.add_named_metadata('nvvmir.version', md_ver)
+
+
 # From Numba
 def _call_sreg(builder, name):
     module = builder.module
@@ -191,6 +209,9 @@ class LLVMPrinter(Printer):
         self._print(func.body)
         self.builder.ret_void()
         self.fn = fn
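+        # GPU kernels must carry the NVVM kernel annotation, otherwise the
+        # PTX backend would emit them as plain device functions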
+        if self.target == 'gpu':
+            set_cuda_kernel(fn)
+
         return fn
 
     def _print_Block(self, block):
diff --git a/pystencils/llvm/llvmjit.py b/pystencils/llvm/llvmjit.py
index 4d85e41885916b7d68109b2a7637883d76d91ab8..f8b0205f5e4f7c4356903ea1c4b04d7bc645763a 100644
--- a/pystencils/llvm/llvmjit.py
+++ b/pystencils/llvm/llvmjit.py
@@ -1,5 +1,7 @@
 import ctypes as ct
 import subprocess
+from functools import partial
+from itertools import chain
 from os.path import exists, join
 
 import llvmlite.binding as llvm
@@ -103,9 +105,9 @@ def generate_and_jit(ast):
     target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
     gen = generate_llvm(ast, target=target)
     if isinstance(gen, ir.Module):
-        return compile_llvm(gen, target)
+        return compile_llvm(gen, target, ast)
     else:
-        return compile_llvm(gen.module, target)
+        return compile_llvm(gen.module, target, ast)
 
 
 def make_python_function(ast, argument_dict={}, func=None):
@@ -120,8 +122,8 @@ def make_python_function(ast, argument_dict={}, func=None):
     return lambda: func(*args)
 
 
-def compile_llvm(module, target='cpu'):
-    jit = CudaJit() if target == "gpu" else Jit()
+def compile_llvm(module, target='cpu', ast=None):
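+    # the ast is only needed for the GPU path: CudaJit reads ast.indexing to derive the launch configuration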
+    jit = CudaJit(ast) if target == "gpu" else Jit()
     jit.parse(module)
     jit.optimize()
     jit.compile()
@@ -243,12 +245,13 @@ class CudaJit(Jit):
 
     default_data_layout = data_layout[MACHINE_BITS]
 
-    def __init__(self):
+    def __init__(self, ast):
         # super().__init__()
 
         # self.target = llvm.Target.from_triple(self.CUDA_TRIPLE[self.MACHINE_BITS])
         self._data_layout = self.default_data_layout[self.MACHINE_BITS]
         # self._target_data = llvm.create_target_data(self._data_layout)
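+        # keep the kernel's indexing scheme so block and grid sizes can be computed per call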
+        self.indexing = ast.indexing
 
     def optimize(self):
         pmb = llvm.create_pass_manager_builder()
@@ -278,7 +281,6 @@ class CudaJit(Jit):
         llvmmod.verify()
         llvmmod.name = 'module'
 
-        self.module = str(llvmmod)
         self._llvmmod = llvm.parse_assembly(str(llvmmod))
 
     def compile(self):
@@ -287,48 +289,30 @@ class CudaJit(Jit):
         compiler_cache = get_cache_config()['object_cache']
         ir_file = join(compiler_cache, hashlib.md5(str(self._llvmmod).encode()).hexdigest() + '.ll')
         ptx_file = ir_file.replace('.ll', '.ptx')
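+        # query the compute capability of the active device; fall back to sm_35 if no CUDA context is available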
+        try:
+            from pycuda.driver import Context
+            arch = "sm_%d%d" % Context.get_device().compute_capability()
+        except Exception:
+            arch = "sm_35"
 
         if not exists(ptx_file):
             self.write_ll(ir_file)
-            try:
-                from pycuda.driver import Context
-                arch = "sm_%d%d" % Context.get_device().compute_capability()
-            except Exception:
-                arch = "sm_35"
-
             subprocess.check_call(['llc-10', '-mcpu=' + arch, ir_file, '-o', ptx_file])
 
-        # TODO: make loading of ptx work
-        # import pycuda.autoinit
-
-        # def handler(compile_success_bool, info_str, error_str):
-            # if not compile_success_bool:
-            # print(info_str)
-            # print(error_str)
+        # cubin_file = ir_file.replace('.ll', '.cubin')
+        # if not exists(cubin_file):
+            # subprocess.check_call(['ptxas', '--gpu-name', arch, ptx_file, '-o', cubin_file])
+        import pycuda.driver
 
-        # # with open(ptx_file, 'rb') as f:
-            # # ptx_code = f.read()
-
-        # # from pycuda.driver import jit_input_type
-        # # self.linker.add_data(ptx_code, jit_input_type.PTX, 'foo')
-        # from pycuda.compiler import DynamicModule
-
-        # from pycuda.driver import jit_input_type
-        # module = DynamicModule().add_file(ptx_file, jit_input_type.PTX)
-        # module.link()
-        # # cuda_module = pycuda.driver.module_from_buffer(ptx_code, message_handler=handler)
-        # # print(dir(cuda_module))
-        # self.fptr = dict()
-        # module.get_function('kernel')
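+        # let the CUDA driver JIT-compile the generated PTX for the current device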
+        cuda_module = pycuda.driver.module_from_file(ptx_file)  # a pre-assembled cubin (ptxas path above) works as well
+        self.cuda_module = cuda_module
 
     def __call__(self, func, *args, **kwargs):
-        fptr = {}
-        for func in self.module.functions:
-            if not func.is_declaration:
-                return_type = None
-                if func.ftype.return_type != ir.VoidType():
-                    return_type = to_ctypes(create_composite_type_from_string(str(func.ftype.return_type)))
-                args = [ctypes_from_llvm(arg) for arg in func.ftype.args]
-                function_address = self.ee.get_function_address(func.name)
-                fptr[func.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
-        self.fptr = fptr
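+        # infer the iteration shape from the first array-like argument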
+        shape = next(a.shape for a in chain(args, kwargs.values()) if hasattr(a, 'shape'))
+        block_and_thread_numbers = self.indexing.call_parameters(shape)
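+        # pycuda expects block and grid dimensions as tuples of plain Python ints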
+        block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
+        block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])
+        self.cuda_module.get_function(func)(*args, **kwargs, **block_and_thread_numbers)
+
+    def get_function_ptr(self, name):
+        return partial(self.__call__, name)
diff --git a/pystencils_tests/test_jacobi_llvm.py b/pystencils_tests/test_jacobi_llvm.py
index a072248dd57422e2b17e11d984de737c9ab2d597..cccc710bf4f5dc1ce89e7369900c763b64319225 100644
--- a/pystencils_tests/test_jacobi_llvm.py
+++ b/pystencils_tests/test_jacobi_llvm.py
@@ -33,13 +33,19 @@ def test_jacobi_fixed_field_size():
 def test_jacobi_fixed_field_size_gpu():
     size = (30, 20)
 
+    import pycuda.autoinit  # noqa: F401 (creates the CUDA context used below)
+    from pycuda.gpuarray import to_gpu
+
     src_field_llvm = np.random.rand(*size)
     src_field_py = np.copy(src_field_llvm)
     dst_field_llvm = np.zeros(size)
     dst_field_py = np.zeros(size)
 
-    f = Field.create_from_numpy_array("f", src_field_llvm)
-    d = Field.create_from_numpy_array("d", dst_field_llvm)
+    f = Field.create_from_numpy_array("f", src_field_py)
+    d = Field.create_from_numpy_array("d", dst_field_py)
+
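+    # the generated kernel operates on device arrays; the host copies feed the pure-Python reference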
+    src_field_llvm = to_gpu(src_field_llvm)
+    dst_field_llvm = to_gpu(dst_field_llvm)
 
     jacobi = Assignment(d[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
     ast = create_kernel([jacobi], target='gpu')
@@ -52,7 +58,7 @@ def test_jacobi_fixed_field_size_gpu():
 
     jit = generate_and_jit(ast)
     jit('kernel', dst_field_llvm, src_field_llvm)
-    error = np.sum(np.abs(dst_field_py - dst_field_llvm))
+    error = np.sum(np.abs(dst_field_py - dst_field_llvm.get()))
     np.testing.assert_almost_equal(error, 0.0)