diff --git a/pystencils/llvm/llvmjit.py b/pystencils/llvm/llvmjit.py
index 6dc40151e4495d5fa67f9c6081a442a558b1d09a..858509441875bb189152303aa84f751728f36d22 100644
--- a/pystencils/llvm/llvmjit.py
+++ b/pystencils/llvm/llvmjit.py
@@ -1,4 +1,6 @@
 import ctypes as ct
+import subprocess
+from os.path import exists, join
 
 import llvmlite.binding as llvm
 import llvmlite.ir as ir
@@ -98,11 +100,12 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict,
 
 
 def generate_and_jit(ast):
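+    # choose the CUDA JIT when the AST was generated for the 'llvm_gpu' backend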
+    target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
     gen = generate_llvm(ast)
     if isinstance(gen, ir.Module):
-        return compile_llvm(gen)
+        return compile_llvm(gen, target)
     else:
-        return compile_llvm(gen.module)
+        return compile_llvm(gen.module, target)
 
 
 def make_python_function(ast, argument_dict={}, func=None):
@@ -117,8 +120,8 @@ def make_python_function(ast, argument_dict={}, func=None):
     return lambda: func(*args)
 
 
-def compile_llvm(module):
-    jit = Jit()
+def compile_llvm(module, target='cpu'):
+    jit = CudaJit() if target == 'gpu' else Jit()
     jit.parse(module)
     jit.optimize()
     jit.compile()
@@ -224,3 +227,108 @@ class Jit(object):
         fptr = self.fptr[name]
         fptr.jit = self
         return fptr
+
+
+# The following CUDA JIT support is adapted, more or less, from numba
+class CudaJit(Jit):
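+    """JIT wrapper that re-targets the generated LLVM module to NVPTX and
+    compiles it to PTX via an external llc call (see compile()); loading and
+    executing the resulting PTX is not implemented yet."""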
+
+    CUDA_TRIPLE = {32: 'nvptx-nvidia-cuda',
+                   64: 'nvptx64-nvidia-cuda'}
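+    # machine word size in bits (a tuple item is a pointer on CPython)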
+    MACHINE_BITS = tuple.__itemsize__ * 8
+    data_layout = {
+        32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
+             'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'),
+        64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
+             'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')}
+
+    default_data_layout = data_layout[MACHINE_BITS]
+
+    def __init__(self):
+        # The CPU Jit's __init__ (MCJIT engine setup) is intentionally not called;
+        # only the data layout is needed to re-target the module for PTX generation.
+        # super().__init__()
+        # self.target = llvm.Target.from_triple(self.CUDA_TRIPLE[self.MACHINE_BITS])
+        self._data_layout = self.default_data_layout
+        # self._target_data = llvm.create_target_data(self._data_layout)
+
+    def optimize(self):
+        pmb = llvm.create_pass_manager_builder()
+        pmb.opt_level = 2
+        pmb.disable_unit_at_a_time = False
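+        # vectorization is switched off for the GPU target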
+        pmb.loop_vectorize = False
+        pmb.slp_vectorize = False
+        # TODO: per-function pass managers could additionally be run here
+        pm = llvm.create_module_pass_manager()
+        pm.add_instruction_combining_pass()
+        pm.add_function_attrs_pass()
+        pm.add_constant_merge_pass()
+        pm.add_licm_pass()
+        pmb.populate(pm)
+        pm.run(self.llvmmod)
+        pm.run(self.llvmmod)
+
+    def write_ll(self, file):
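+        # dump the current LLVM IR as text (used as the input for llc)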
+        with open(file, 'w') as f:
+            f.write(str(self.llvmmod))
+
+    def parse(self, module):
+        # re-target the generated module to the NVPTX triple and data layout
+        llvmmod = module
+        llvmmod.triple = self.CUDA_TRIPLE[self.MACHINE_BITS]
+        llvmmod.data_layout = self.default_data_layout
+        llvmmod.name = 'module'
+
+        # keep the ir.Module (its function signatures are inspected in __call__)
+        # and parse/verify the textual IR with the binding layer
+        self.module = llvmmod
+        self._llvmmod = llvm.parse_assembly(str(llvmmod))
+        self._llvmmod.verify()
+
+    def compile(self):
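+        # Emit the re-targeted IR to a .ll file and compile it to PTX with llc.
+        # Results are cached in pystencils' object cache, keyed by an md5 hash
+        # of the IR.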
+        from pystencils.cpu.cpujit import get_cache_config
+        import hashlib
+        compiler_cache = get_cache_config()['object_cache']
+        ir_file = join(compiler_cache, hashlib.md5(str(self._llvmmod).encode()).hexdigest() + '.ll')
+        ptx_file = ir_file.replace('.ll', '.ptx')
+
+        if not exists(ptx_file):
+            self.write_ll(ir_file)
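+            # query the compute capability via pycuda if available,
+            # otherwise fall back to sm_35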
+            try:
+                from pycuda.driver import Context
+                arch = "sm_%d%d" % Context.get_device().compute_capability()
+            except Exception:
+                arch = "sm_35"
+
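+            # NOTE: the llc binary name/version (llc-10) is currently hard-coded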
+            subprocess.check_call(['llc-10', '-mcpu=' + arch, ir_file, '-o', ptx_file])
+
+        # TODO: make loading of the generated PTX work, e.g. via pycuda:
+        #
+        # import pycuda.autoinit
+        # from pycuda.compiler import DynamicModule
+        # from pycuda.driver import jit_input_type
+        #
+        # module = DynamicModule().add_file(ptx_file, jit_input_type.PTX)
+        # module.link()
+        # self.fptr = dict()
+        # module.get_function('kernel')
+        #
+        # or from an in-memory buffer, with a handler for compile diagnostics:
+        #
+        # def handler(compile_success_bool, info_str, error_str):
+        #     if not compile_success_bool:
+        #         print(info_str)
+        #         print(error_str)
+        #
+        # with open(ptx_file, 'rb') as f:
+        #     ptx_code = f.read()
+        # cuda_module = pycuda.driver.module_from_buffer(ptx_code, message_handler=handler)
+
+    def __call__(self, func, *args, **kwargs):
+        # Mirrors the CPU Jit's function-pointer lookup. Note that CudaJit does
+        # not set up an execution engine (self.ee), so this only becomes usable
+        # once loading of the generated PTX is implemented (see compile()).
+        fptr = {}
+        for function in self.module.functions:
+            if not function.is_declaration:
+                return_type = None
+                if function.ftype.return_type != ir.VoidType():
+                    return_type = to_ctypes(create_composite_type_from_string(str(function.ftype.return_type)))
+                arg_types = [ctypes_from_llvm(arg) for arg in function.ftype.args]
+                function_address = self.ee.get_function_address(function.name)
+                fptr[function.name] = ct.CFUNCTYPE(return_type, *arg_types)(function_address)
+        self.fptr = fptr