Commit a526fe47 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make compilation to PTX work (loading of function is WIP)

parent 9d3e1113
import ctypes as ct
import subprocess
from os.path import exists, join
import llvmlite.binding as llvm
import llvmlite.ir as ir
......@@ -98,11 +100,12 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict,
def generate_and_jit(ast):
target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
gen = generate_llvm(ast)
if isinstance(gen, ir.Module):
return compile_llvm(gen)
return compile_llvm(gen, target)
else:
return compile_llvm(gen.module)
return compile_llvm(gen.module, target)
def make_python_function(ast, argument_dict={}, func=None):
......@@ -117,8 +120,8 @@ def make_python_function(ast, argument_dict={}, func=None):
return lambda: func(*args)
def compile_llvm(module):
jit = Jit()
def compile_llvm(module, target='cpu'):
jit = CudaJit() if target == "gpu" else Jit()
jit.parse(module)
jit.optimize()
jit.compile()
......@@ -224,3 +227,108 @@ class Jit(object):
fptr = self.fptr[name]
fptr.jit = self
return fptr
# Following code more or less from numba
class CudaJit(Jit):
CUDA_TRIPLE = {32: 'nvptx-nvidia-cuda',
64: 'nvptx64-nvidia-cuda'}
MACHINE_BITS = tuple.__itemsize__ * 8
data_layout = {
32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'),
64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')}
default_data_layout = data_layout[MACHINE_BITS]
def __init__(self):
# super().__init__()
# self.target = llvm.Target.from_triple(self.CUDA_TRIPLE[self.MACHINE_BITS])
self._data_layout = self.default_data_layout[self.MACHINE_BITS]
# self._target_data = llvm.create_target_data(self._data_layout)
def optimize(self):
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 2
pmb.disable_unit_at_a_time = False
pmb.loop_vectorize = False
pmb.slp_vectorize = False
# TODO possible to pass for functions
pm = llvm.create_module_pass_manager()
pm.add_instruction_combining_pass()
pm.add_function_attrs_pass()
pm.add_constant_merge_pass()
pm.add_licm_pass()
pmb.populate(pm)
pm.run(self.llvmmod)
pm.run(self.llvmmod)
def write_ll(self, file):
with open(file, 'w') as f:
f.write(str(self.llvmmod))
def parse(self, module):
llvmmod = module
llvmmod.triple = self.CUDA_TRIPLE[self.MACHINE_BITS]
llvmmod.data_layout = self.default_data_layout
llvmmod.verify()
llvmmod.name = 'module'
self.module = str(llvmmod)
self._llvmmod = llvm.parse_assembly(str(llvmmod))
def compile(self):
from pystencils.cpu.cpujit import get_cache_config
import hashlib
compiler_cache = get_cache_config()['object_cache']
ir_file = join(compiler_cache, hashlib.md5(str(self._llvmmod).encode()).hexdigest() + '.ll')
ptx_file = ir_file.replace('.ll', '.ptx')
if not exists(ptx_file):
self.write_ll(ir_file)
try:
from pycuda.driver import Context
arch = "sm_%d%d" % Context.get_device().compute_capability()
except Exception:
arch = "sm_35"
subprocess.check_call(['llc-10', '-mcpu=' + arch, ir_file, '-o', ptx_file])
# TODO: make loading of ptx work
# import pycuda.autoinit
# def handler(compile_success_bool, info_str, error_str):
# if not compile_success_bool:
# print(info_str)
# print(error_str)
# # with open(ptx_file, 'rb') as f:
# # ptx_code = f.read()
# # from pycuda.driver import jit_input_type
# # self.linker.add_data(ptx_code, jit_input_type.PTX, 'foo')
# from pycuda.compiler import DynamicModule
# from pycuda.driver import jit_input_type
# module = DynamicModule().add_file(ptx_file, jit_input_type.PTX)
# module.link()
# # cuda_module = pycuda.driver.module_from_buffer(ptx_code, message_handler=handler)
# # print(dir(cuda_module))
# self.fptr = dict()
# module.get_function('kernel')
def __call__(self, func, *args, **kwargs):
fptr = {}
for func in self.module.functions:
if not func.is_declaration:
return_type = None
if func.ftype.return_type != ir.VoidType():
return_type = to_ctypes(create_composite_type_from_string(str(func.ftype.return_type)))
args = [ctypes_from_llvm(arg) for arg in func.ftype.args]
function_address = self.ee.get_function_address(func.name)
fptr[func.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
self.fptr = fptr
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment