Skip to content
Snippets Groups Projects
Commit 805f6cc8 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

LLVM-Backend can create a CFUNCTYPES function and run it.

parent 56af5641
No related merge requests found
import llvmlite.ir as ir import llvmlite.ir as ir
import llvmlite.binding as llvm import llvmlite.binding as llvm
import logging.config
from ..types import toCtypes, createType from ..types import toCtypes, createType
import ctypes as ct import ctypes as ct
logger = logging.getLogger(__name__)
def compileLLVM(module): def compileLLVM(module):
return Eval().compile(module) jit = Jit()
jit.parse(module)
jit.optimize()
jit.compile()
return jit
class Eval(object): class Jit(object):
def __init__(self): def __init__(self):
llvm.initialize() llvm.initialize()
llvm.initialize_all_targets() llvm.initialize_all_targets()
llvm.initialize_native_target() llvm.initialize_native_target()
llvm.initialize_native_asmprinter() llvm.initialize_native_asmprinter()
self.module = None
self.llvmmod = None
self.target = llvm.Target.from_default_triple() self.target = llvm.Target.from_default_triple()
self.cpu = llvm.get_host_cpu_name()
self.cpu_features = llvm.get_host_cpu_features()
self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), opt=2)
self.ee = None
self.fptr = None
def compile(self, module): def parse(self, module):
logger.debug('=============Preparse') self.module = module
logger.debug(str(module))
llvmmod = llvm.parse_assembly(str(module)) llvmmod = llvm.parse_assembly(str(module))
llvmmod.verify() llvmmod.verify()
logger.debug('=============Function in IR') self.llvmmod = llvmmod
logger.debug(str(llvmmod))
# TODO cpu, features, opt
cpu = llvm.get_host_cpu_name()
features = llvm.get_host_cpu_features()
logger.debug('=======Things')
logger.debug(cpu)
logger.debug(features.flatten())
target_machine = self.target.create_target_machine(cpu=cpu, features=features.flatten(), opt=2)
logger.debug('Machine = ' + str(target_machine.target_data)) def write_ll(self, file):
with open(file, 'w') as f:
f.write(str(self.llvmmod))
with open('gen.ll', 'w') as f: def optimize(self):
f.write(str(llvmmod)) pmb = llvm.create_pass_manager_builder()
optimize = True pmb.opt_level = 2
if optimize: pmb.disable_unit_at_a_time = False
pmb = llvm.create_pass_manager_builder() pmb.loop_vectorize = True
pmb.opt_level = 2 pmb.slp_vectorize = True
pmb.disable_unit_at_a_time = False # TODO possible to pass for functions
pmb.loop_vectorize = True pm = llvm.create_module_pass_manager()
pmb.slp_vectorize = True pm.add_instruction_combining_pass()
# TODO possible to pass for functions pm.add_function_attrs_pass()
pm = llvm.create_module_pass_manager() pm.add_constant_merge_pass()
pm.add_instruction_combining_pass() pm.add_licm_pass()
pm.add_function_attrs_pass() pmb.populate(pm)
pm.add_constant_merge_pass() pm.run(self.llvmmod)
pm.add_licm_pass()
pmb.populate(pm)
pm.run(llvmmod)
logger.debug("==========Opt")
logger.debug(str(llvmmod))
with open('gen_opt.ll', 'w') as f:
f.write(str(llvmmod))
with llvm.create_mcjit_compiler(llvmmod, target_machine) as ee: def compile(self, assembly_file=None, object_file=None):
ee.finalize_object() ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)
ee.finalize_object()
logger.debug('==========Machine code') if assembly_file is not None:
logger.debug(target_machine.emit_assembly(llvmmod)) with open(assembly_file, 'w') as f:
with open('gen.S', 'w') as f: f.write(self.target_machine.emit_assembly(self.llvmmod))
f.write(target_machine.emit_assembly(llvmmod)) if object_file is not None:
with open('gen.o', 'wb') as f: with open(object_file, 'wb') as f:
f.write(target_machine.emit_object(llvmmod)) f.write(self.target_machine.emit_object(self.llvmmod))
fptr = {} fptr = {}
for function in module.functions: for function in self.module.functions:
if not function.is_declaration: if not function.is_declaration:
return_type = None
if function.ftype.return_type != ir.VoidType():
return_type = toCtypes(createType(str(function.ftype.return_type)))
args = [toCtypes(createType(str(arg))) for arg in function.ftype.args]
function_address = ee.get_function_address(function.name)
fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
self.ee = ee
self.fptr = fptr
print(function.name) def __call__(self, function, *args, **kwargs):
print(type(function)) self.fptr[function](*args, **kwargs)
print(function.ftype.return_type)
print(type(function.ftype.return_type))
return_type = None
if function.ftype.return_type != ir.VoidType():
return_type = toCtypes(createType(str(function.ftype.return_type)))
args = [toCtypes(createType(str(arg))) for arg in function.ftype.args]
function_address = ee.get_function_address(function.name)
fptr[function.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
# result = fptr(2, 3)
# print(result)
return fptr
...@@ -60,8 +60,10 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl ...@@ -60,8 +60,10 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code) moveConstantsBeforeLoop(code)
print('Ast:')
print(code) print(code)
desympy_ast(code) desympy_ast(code)
print('Desympied ast:')
print(code) print(code)
insert_casts(code) insert_casts(code)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment