llvmjit.py 4.49 KB
Newer Older
1
2
3
4
import llvmlite.ir as ir
import llvmlite.binding as llvm
import numpy as np
import ctypes as ct
Martin Bauer's avatar
Martin Bauer committed
5
6
from pystencils.data_types import create_composite_type_from_string
from ..data_types import to_ctypes, ctypes_from_llvm
Martin Bauer's avatar
Martin Bauer committed
7
8
from .llvm import generate_llvm
from ..cpu.cpujit import build_ctypes_argument_list, make_python_function_incomplete_params
9
10
11


def generate_and_jit(ast):
    """Lower *ast* to LLVM IR and JIT-compile it.

    ``generate_llvm`` may return either a bare ``ir.Module`` or an object that
    exposes the module through a ``.module`` attribute; both are accepted.

    Returns:
        The ``Jit`` instance produced by ``compile_llvm``.
    """
    generated = generate_llvm(ast)
    llvm_module = generated if isinstance(generated, ir.Module) else generated.module
    return compile_llvm(llvm_module)
17
18


Martin Bauer's avatar
Martin Bauer committed
19
def make_python_function(ast, argument_dict=None, func=None):
    """Return a zero-argument Python callable that runs the JIT-compiled kernel for *ast*.

    Parameters:
        ast: kernel AST; its ``function_name`` and ``parameters`` attributes are used.
        argument_dict: values for the kernel parameters (presumably keyed by
            parameter name -- see ``build_ctypes_argument_list``). Defaults to a
            fresh empty dict on every call; a shared ``{}`` default would be reused
            across calls.
        func: an already compiled function pointer. If ``None``, *ast* is compiled
            here via ``generate_and_jit``.

    Returns:
        ``lambda: func(*args)`` with the ctypes arguments pre-built. If
        ``build_ctypes_argument_list`` raises ``KeyError`` (not all parameters
        specified yet), the result of ``make_python_function_incomplete_params``
        is returned instead.
    """
    if argument_dict is None:
        argument_dict = {}
    if func is None:
        jit = generate_and_jit(ast)
        func = jit.get_function_ptr(ast.function_name)
    try:
        args = build_ctypes_argument_list(ast.parameters, argument_dict)
    except KeyError:
        # not all parameters specified yet
        return make_python_function_incomplete_params(ast, argument_dict, func)
    return lambda: func(*args)


Martin Bauer's avatar
Martin Bauer committed
31
def compile_llvm(module):
    """Create a ``Jit`` for the llvmlite *module* and make its functions callable.

    The module is parsed into the execution engine, the optimization pipeline
    is run, and ctypes function pointers are collected.

    Returns:
        The prepared ``Jit``; fetch kernels with ``Jit.get_function_ptr``.
    """
    engine = Jit()
    engine.parse(module)
    # Both remaining stages take no arguments and must run after parse().
    for stage in (engine.optimize, engine.compile):
        stage()
    return engine


class Jit(object):
    """MCJIT execution engine that turns an llvmlite ``ir.Module`` into ctypes callables.

    Typical sequence (as used by ``compile_llvm``): ``parse`` -> ``optimize`` ->
    ``compile``. Compiled functions are then available through
    ``get_function_ptr`` or by calling the instance with a function name.
    """

    def __init__(self):
        llvm.initialize()
        llvm.initialize_all_targets()
        llvm.initialize_native_target()
        llvm.initialize_native_asmprinter()

        # High-level llvmlite.ir module last passed to parse(); compile() and
        # __call__ enumerate its functions.
        self.module = None
        # Start with an empty module so the execution engine can be created
        # before anything has been parsed.
        self._llvmmod = llvm.parse_assembly("")
        self.target = llvm.Target.from_default_triple()
        self.cpu = llvm.get_host_cpu_name()
        self.cpu_features = llvm.get_host_cpu_features()
        # Target machine tuned for the host CPU and its feature set.
        self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(),
                                                                opt=2)
        llvm.check_jit_execution()
        self.ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)
        self.ee.finalize_object()
        # Maps function name -> ctypes function object; filled in by compile().
        self.fptr = None

    @property
    def llvmmod(self):
        """The binding-level module currently registered with the execution engine."""
        return self._llvmmod

    @llvmmod.setter
    def llvmmod(self, mod):
        """Swap the engine's module for *mod*, finalize it and refresh ``fptr``.

        ``compile()`` enumerates functions from ``self.module`` (not from *mod*),
        so ``self.module`` should already describe *mod*; ``parse()`` sets it
        before assigning here.
        """
        self.ee.remove_module(self.llvmmod)
        self.ee.add_module(mod)
        self.ee.finalize_object()
        self.compile()
        self._llvmmod = mod

    def parse(self, module):
        """Convert the llvmlite.ir *module* to a verified binding module and load it.

        Assigning ``self.llvmmod`` registers the module with the engine,
        finalizes it and runs ``compile()``.
        """
        self.module = module
        llvmmod = llvm.parse_assembly(str(module))
        llvmmod.verify()
        llvmmod.triple = self.target.triple
        llvmmod.name = 'module'
        self.llvmmod = llvmmod

    def write_ll(self, file):
        """Write the textual LLVM IR of the current module to the path *file*."""
        with open(file, 'w') as f:
            f.write(str(self.llvmmod))

    def write_assembly(self, file):
        """Write target assembly for the current module to the path *file*."""
        with open(file, 'w') as f:
            f.write(self.target_machine.emit_assembly(self.llvmmod))

    def write_object_file(self, file):
        """Write a native object file for the current module to the path *file*."""
        with open(file, 'wb') as f:
            f.write(self.target_machine.emit_object(self.llvmmod))

    def optimize(self):
        """Run an -O2 module pass pipeline (with loop and SLP vectorization) on ``llvmmod``.

        NOTE(review): ``parse()`` has already added and finalized this module in
        the execution engine (via the ``llvmmod`` setter), so these passes may
        not affect the machine code whose addresses ``compile()`` returns --
        confirm against MCJIT semantics.
        """
        pmb = llvm.create_pass_manager_builder()
        pmb.opt_level = 2
        pmb.disable_unit_at_a_time = False
        pmb.loop_vectorize = True
        pmb.slp_vectorize = True
        # TODO: possibly also run passes per function
        pm = llvm.create_module_pass_manager()
        pm.add_instruction_combining_pass()
        pm.add_function_attrs_pass()
        pm.add_constant_merge_pass()
        pm.add_licm_pass()
        pmb.populate(pm)
        pm.run(self.llvmmod)

    def compile(self):
        """Build ``self.fptr``: a ctypes function object for every defined function.

        Declarations (functions without a body) are skipped. Void functions get a
        ``None`` ctypes return type.
        """
        fptr = {}
        for func in self.module.functions:
            if not func.is_declaration:
                return_type = None
                if func.ftype.return_type != ir.VoidType():
                    return_type = to_ctypes(create_composite_type_from_string(str(func.ftype.return_type)))
                args = [ctypes_from_llvm(arg) for arg in func.ftype.args]
                function_address = self.ee.get_function_address(func.name)
                fptr[func.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
        self.fptr = fptr

    def __call__(self, func, *args, **kwargs):
        """Call the compiled function named *func* with *args* and return its result.

        NumPy arrays are converted with ``ndarray.ctypes.data_as`` using the ctypes
        type derived from the matching LLVM parameter; all other arguments are
        passed through unchanged. *kwargs* is accepted but ignored.

        Raises:
            KeyError: if ``self.module`` has no function named *func*.
        """
        # A bare next() would leak StopIteration (turned into RuntimeError when
        # called from inside a generator); report a missing name as KeyError,
        # like the self.fptr lookup below.
        target_function = next((f for f in self.module.functions if f.name == func), None)
        if target_function is None:
            raise KeyError(func)
        arg_types = [ctypes_from_llvm(arg.type) for arg in target_function.args]

        transformed_args = []
        for i, arg in enumerate(args):
            if isinstance(arg, np.ndarray):
                transformed_args.append(arg.ctypes.data_as(arg_types[i]))
            else:
                transformed_args.append(arg)

        return self.fptr[func](*transformed_args)

    def print_functions(self):
        """Print return type, name and arguments of every function in ``self.module``."""
        for f in self.module.functions:
            print(f.ftype.return_type, f.name, f.args)

    def get_function_ptr(self, name):
        """Return the compiled function *name* from ``self.fptr``.

        A reference to this ``Jit`` is attached as ``.jit`` on the returned object,
        presumably so the engine is not garbage-collected while the function is
        still in use.
        """
        fptr = self.fptr[name]
        fptr.jit = self
        return fptr