diff --git a/pystencils/llvm/llvmjit.py b/pystencils/llvm/llvmjit.py
index 6dc40151e4495d5fa67f9c6081a442a558b1d09a..858509441875bb189152303aa84f751728f36d22 100644
--- a/pystencils/llvm/llvmjit.py
+++ b/pystencils/llvm/llvmjit.py
@@ -1,4 +1,6 @@
 import ctypes as ct
+import subprocess
+from os.path import exists, join
 
 import llvmlite.binding as llvm
 import llvmlite.ir as ir
@@ -98,11 +100,12 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict,
 
 
 def generate_and_jit(ast):
+    target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
     gen = generate_llvm(ast)
     if isinstance(gen, ir.Module):
-        return compile_llvm(gen)
+        return compile_llvm(gen, target)
     else:
-        return compile_llvm(gen.module)
+        return compile_llvm(gen.module, target)
 
 
 def make_python_function(ast, argument_dict={}, func=None):
@@ -117,8 +120,8 @@ def make_python_function(ast, argument_dict={}, func=None):
     return lambda: func(*args)
 
 
-def compile_llvm(module):
-    jit = Jit()
+def compile_llvm(module, target='cpu'):
+    jit = CudaJit() if target == "gpu" else Jit()
     jit.parse(module)
     jit.optimize()
     jit.compile()
@@ -224,3 +227,113 @@ class Jit(object):
         fptr = self.fptr[name]
         fptr.jit = self
         return fptr
+
+
+# The following code is more or less adapted from numba
+class CudaJit(Jit):
+
+    CUDA_TRIPLE = {32: 'nvptx-nvidia-cuda',
+                   64: 'nvptx64-nvidia-cuda'}
+    MACHINE_BITS = tuple.__itemsize__ * 8
+    data_layout = {
+        32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
+             'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'),
+        64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
+             'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')}
+
+    default_data_layout = data_layout[MACHINE_BITS]
+
+    def __init__(self):
+        # super().__init__()
+
+        # self.target = llvm.Target.from_triple(self.CUDA_TRIPLE[self.MACHINE_BITS])
+        self._data_layout = self.data_layout[self.MACHINE_BITS]
+        # self._target_data = llvm.create_target_data(self._data_layout)
+
+    def optimize(self):
+        pmb = llvm.create_pass_manager_builder()
+        pmb.opt_level = 2
+        pmb.disable_unit_at_a_time = False
+        pmb.loop_vectorize = False
+        pmb.slp_vectorize = False
+        # TODO: a function pass manager could be used here as well
+        pm = llvm.create_module_pass_manager()
+        pm.add_instruction_combining_pass()
+        pm.add_function_attrs_pass()
+        pm.add_constant_merge_pass()
+        pm.add_licm_pass()
+        pmb.populate(pm)
+        pm.run(self.llvmmod)
+        pm.run(self.llvmmod)
+
+    def write_ll(self, file):
+        with open(file, 'w') as f:
+            f.write(str(self.llvmmod))
+
+    def parse(self, module):
+        llvmmod = module
+        llvmmod.triple = self.CUDA_TRIPLE[self.MACHINE_BITS]
+        llvmmod.data_layout = self.default_data_layout
+        llvmmod.name = 'module'
+
+        self.module = llvmmod
+        # llvmlite.ir modules have no verify(); re-parse and verify the
+        # binding module instead
+        self._llvmmod = llvm.parse_assembly(str(llvmmod))
+        self._llvmmod.verify()
+
+    def compile(self):
+        from pystencils.cpu.cpujit import get_cache_config
+        import hashlib
+        compiler_cache = get_cache_config()['object_cache']
+        ir_file = join(compiler_cache, hashlib.md5(str(self._llvmmod).encode()).hexdigest() + '.ll')
+        ptx_file = ir_file.replace('.ll', '.ptx')
+
+        if not exists(ptx_file):
+            self.write_ll(ir_file)
+            try:
+                from pycuda.driver import Context
+                arch = "sm_%d%d" % Context.get_device().compute_capability()
+            except Exception:
+                arch = "sm_35"
+
+            # lower the IR to PTX with LLVM's static compiler (requires llc-10 on PATH)
+            subprocess.check_call(['llc-10', '-mcpu=' + arch, ir_file, '-o', ptx_file])
+
+        # TODO: make loading of ptx work
+        # import pycuda.autoinit
+
+        # def handler(compile_success_bool, info_str, error_str):
+        #     if not compile_success_bool:
+        #         print(info_str)
+        #         print(error_str)
+
+        # # with open(ptx_file, 'rb') as f:
+        # #     ptx_code = f.read()
+
+        # # from pycuda.driver import jit_input_type
+        # # self.linker.add_data(ptx_code, jit_input_type.PTX, 'foo')
+        # from pycuda.compiler import DynamicModule
+
+        # from pycuda.driver import jit_input_type
+        # module = DynamicModule().add_file(ptx_file, jit_input_type.PTX)
+        # module.link()
+        # # cuda_module = pycuda.driver.module_from_buffer(ptx_code, message_handler=handler)
+        # # print(dir(cuda_module))
+        # self.fptr = dict()
+        # module.get_function('kernel')
+
+    def __call__(self, func, *args, **kwargs):
+        # WIP: builds ctypes wrappers for all defined functions, mirroring
+        # Jit.compile; this still needs the execution engine (self.ee) set up
+        # in Jit.__init__, which the CUDA path does not create yet
+        fptr = {}
+        for function in self.module.functions:
+            if not function.is_declaration:
+                return_type = None
+                if function.ftype.return_type != ir.VoidType():
+                    return_type = to_ctypes(create_composite_type_from_string(str(function.ftype.return_type)))
+                arg_types = [ctypes_from_llvm(arg) for arg in function.ftype.args]
+                function_address = self.ee.get_function_address(function.name)
+                fptr[function.name] = ct.CFUNCTYPE(return_type, *arg_types)(function_address)
+        self.fptr = fptr
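
Note (not part of the patch): a minimal sketch of how the generated PTX could be
loaded once the "make loading of ptx work" TODO is resolved. module_from_file and
get_function are existing pycuda.driver APIs; ptx_file and the kernel name
'kernel' refer to the values used in CudaJit.compile above, and the launch
configuration is hypothetical.

    import pycuda.autoinit     # initializes CUDA and creates a context on device 0
    import pycuda.driver as cuda

    mod = cuda.module_from_file(ptx_file)  # ptx_file as computed in CudaJit.compile
    kernel = mod.get_function('kernel')    # symbol name from the commented-out code
    # kernel(arg_buffers, block=(128, 1, 1), grid=(num_blocks, 1)) would launch it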