llvmjit.py 12.2 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import ctypes as ct
2
import subprocess
3
4
from functools import partial
from itertools import chain
5
from os.path import exists, join
Martin Bauer's avatar
Martin Bauer committed
6

7
import llvmlite.binding as llvm
Martin Bauer's avatar
Martin Bauer committed
8
import llvmlite.ir as ir
9
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
10

Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.data_types import create_composite_type_from_string
Martin Bauer's avatar
Martin Bauer committed
12
13
from pystencils.field import FieldType

Martin Bauer's avatar
Martin Bauer committed
14
15
16
from ..data_types import StructType, ctypes_from_llvm, to_ctypes
from .llvm import generate_llvm

Martin Bauer's avatar
Martin Bauer committed
17
18

def build_ctypes_argument_list(parameter_specification, argument_dict):
    """Convert named kernel arguments into the positional ctypes argument list.

    Args:
        parameter_specification: kernel parameters in call order, as returned by the
            kernel AST's ``get_parameters()``.
        argument_dict: mapping from field names / scalar symbol names to the values
            to pass (numpy arrays for field parameters).

    Returns:
        A list with one ctypes value per entry of ``parameter_specification``.

    Raises:
        KeyError: a field or scalar parameter is missing from ``argument_dict``.
        ValueError: an array does not match the fixed shape or strides of its symbolic
            field, an array stride is not a multiple of its item size, or the passed
            arrays (resp. index arrays) disagree in their spatial shape.
        AssertionError: a field parameter is neither pointer, shape nor stride.
    """
    argument_dict = dict(argument_dict)  # private copy; the caller's mapping is never touched
    ct_arguments = []
    array_shapes = set()  # spatial shapes of generic (non-index) arrays, must all agree
    index_arr_shapes = set()  # spatial shapes of index arrays, must all agree

    for param in parameter_specification:
        if param.is_field_parameter:
            try:
                field_arr = argument_dict[param.field_name]
            except KeyError:
                raise KeyError("Missing field parameter for kernel call " + param.field_name) from None

            symbolic_field = param.fields[0]
            if param.is_field_pointer:
                ct_arguments.append(field_arr.ctypes.data_as(to_ctypes(param.symbol.dtype)))
                if symbolic_field.has_fixed_shape:
                    # for struct dtypes the last symbolic dimension is dropped before comparing
                    is_struct = isinstance(symbolic_field.dtype, StructType)

                    symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
                    if is_struct:
                        symbolic_field_shape = symbolic_field_shape[:-1]
                    if symbolic_field_shape != field_arr.shape:
                        raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                         (param.field_name, str(field_arr.shape), str(symbolic_field_shape)))

                    # symbolic strides are in items, numpy strides in bytes
                    symbolic_field_strides = tuple(int(i) * field_arr.itemsize for i in symbolic_field.strides)
                    if is_struct:
                        symbolic_field_strides = symbolic_field_strides[:-1]
                    if symbolic_field_strides != field_arr.strides:
                        raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                         (param.field_name, str(field_arr.strides), str(symbolic_field_strides)))

                if FieldType.is_indexed(symbolic_field):
                    index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
                elif FieldType.is_generic(symbolic_field):
                    array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

            elif param.is_field_shape:
                data_type = to_ctypes(param.symbol.dtype)
                ct_arguments.append(data_type(field_arr.shape[param.symbol.coordinate]))
            elif param.is_field_stride:
                data_type = to_ctypes(param.symbol.dtype)
                byte_stride = field_arr.strides[param.symbol.coordinate]
                # the kernel expects strides in items; an explicit check survives `python -O`
                if byte_stride % field_arr.itemsize != 0:
                    raise ValueError("Passed array '%s' has stride %d which is not a multiple of its item size %d" %
                                     (param.field_name, byte_stride, field_arr.itemsize))
                ct_arguments.append(data_type(byte_stride // field_arr.itemsize))
            else:
                raise AssertionError("Unsupported field parameter kind for '%s'" % param.field_name)
        else:
            try:
                value = argument_dict[param.symbol.name]
            except KeyError:
                raise KeyError("Missing parameter for kernel call " + param.symbol.name) from None
            expected_type = to_ctypes(param.symbol.dtype)
            ct_arguments.append(expected_type(value))

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
    if len(index_arr_shapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    return ct_arguments


def make_python_function_incomplete_params(kernel_function_node, argument_dict, func):
    """Return a kernel wrapper whose remaining arguments are passed as keywords per call.

    Args:
        kernel_function_node: kernel AST; its ``get_parameters()`` define the call signature.
        argument_dict: arguments fixed at creation time; keyword arguments given to the
            wrapper are merged on top of (and override) them.
        func: compiled kernel, called with the positional ctypes argument list.

    Returns:
        ``wrapper(**kwargs)`` carrying the attributes ``ast`` and ``parameters``.

    The converted ctypes argument list is cached per distinct combination of keyword
    arguments. Numpy arrays are identified by data pointer, strides and shape. Hashable
    values are identified by type and value. Anything else is identified by object
    identity.
    """
    parameters = kernel_function_node.get_parameters()

    cache = {}
    cache_values = []  # keep kwargs alive so identity-based (id()) key entries stay unique

    def cache_key(kwargs):
        # Use the tuple itself as dict key (not its hash) so collisions cannot alias entries.
        entries = []
        for name in sorted(kwargs):
            value = kwargs[name]
            if isinstance(value, np.ndarray):
                entries.append((name, 'array', value.ctypes.data, value.strides, value.shape))
                continue
            try:
                hash(value)
            except TypeError:
                entries.append((name, 'id', id(value)))
            else:
                # key by value: equal scalars (e.g. a new float object every time step)
                # reuse one cache entry instead of adding a new one per call
                entries.append((name, 'value', type(value), value))
        return tuple(entries)

    def wrapper(**kwargs):
        key = cache_key(kwargs)
        args = cache.get(key)
        if args is None:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            args = build_ctypes_argument_list(parameters, full_arguments)
            cache[key] = args
            cache_values.append(kwargs)
        func(*args)

    wrapper.ast = kernel_function_node
    wrapper.parameters = kernel_function_node.get_parameters()
    return wrapper


def generate_and_jit(ast):
    """Generate LLVM IR for ``ast`` and compile it.

    The AST's ``_backend`` selects the target: ``'llvm_gpu'`` compiles for the GPU,
    everything else for the host CPU. Returns the JIT object built by ``compile_llvm``.
    """
    target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
    generated = generate_llvm(ast, target=target)
    # generate_llvm returns either the ir.Module itself or an object wrapping it in `.module`
    ir_module = generated if isinstance(generated, ir.Module) else generated.module
    return compile_llvm(ir_module, target, ast)


def make_python_function(ast, argument_dict=None, func=None):
    """Compile ``ast`` (unless ``func`` is given) and return a Python callable for it.

    Args:
        ast: kernel function AST.
        argument_dict: kernel arguments known now (field names / symbol names to
            values). ``None`` means no arguments are known yet.
        func: an already compiled kernel; if ``None`` the AST is compiled via
            ``generate_and_jit``.

    Returns:
        If ``argument_dict`` covers all parameters, a zero-argument callable with
        the ctypes arguments pre-built. Otherwise a wrapper that takes the missing
        arguments as keywords (see ``make_python_function_incomplete_params``).
    """
    if argument_dict is None:
        argument_dict = {}  # fresh dict per call instead of a shared mutable default
    if func is None:
        jit = generate_and_jit(ast)
        func = jit.get_function_ptr(ast.function_name)
    try:
        args = build_ctypes_argument_list(ast.get_parameters(), argument_dict)
    except KeyError:
        # not all parameters specified yet -> convert them at call time instead
        return make_python_function_incomplete_params(ast, argument_dict, func)

    def kernel():
        return func(*args)

    # same introspection attributes as the incomplete-parameter wrapper
    kernel.ast = ast
    kernel.parameters = ast.get_parameters()
    return kernel


def compile_llvm(module, target='cpu', ast=None):
    """Parse, optimize and compile ``module`` and return the resulting JIT object.

    ``target == "gpu"`` selects ``CudaJit`` (constructed from ``ast``). Any other value
    selects the host ``Jit``.
    """
    if target == "gpu":
        jit = CudaJit(ast)
    else:
        jit = Jit()

    jit.parse(module)
    jit.optimize()
    jit.compile()
    return jit


class Jit(object):
    """Host-CPU JIT built on llvmlite's MCJIT engine.

    Typical use (see ``compile_llvm``): ``parse(module)``, ``optimize()``, ``compile()``.
    After that, ``get_function_ptr(name)`` returns a ctypes function, or
    ``jit(name, *args)`` calls one directly.
    """

    def __init__(self):
        llvm.initialize()
        llvm.initialize_all_targets()
        llvm.initialize_native_target()
        llvm.initialize_native_asmprinter()

        self.module = None  # the IR module last passed to parse()
        self._llvmmod = llvm.parse_assembly("")  # empty placeholder so the engine can be created
        self.target = llvm.Target.from_default_triple()
        self.cpu = llvm.get_host_cpu_name()
        self.cpu_features = llvm.get_host_cpu_features()
        self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(),
                                                                opt=2)
        llvm.check_jit_execution()
        self.ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine)
        self.ee.finalize_object()
        self.fptr = None  # function name -> ctypes function object, filled by compile()

    @property
    def llvmmod(self):
        """The parsed module currently owned by the execution engine."""
        return self._llvmmod

    @llvmmod.setter
    def llvmmod(self, mod):
        self.ee.remove_module(self._llvmmod)
        self.ee.add_module(mod)
        # record the new module right away so _llvmmod always names the module the
        # engine owns, even if finalization or compilation below raises
        self._llvmmod = mod
        self.ee.finalize_object()
        self.compile()

    def parse(self, module):
        """Re-parse ``module`` for the host triple and hand it to the execution engine."""
        self.module = module
        llvmmod = llvm.parse_assembly(str(module))
        llvmmod.verify()
        llvmmod.triple = self.target.triple
        llvmmod.name = 'module'
        self.llvmmod = llvmmod

    def write_ll(self, file):
        """Write the textual LLVM IR of the current module to ``file``."""
        with open(file, 'w') as f:
            f.write(str(self.llvmmod))

    def write_assembly(self, file):
        """Write the target assembly of the current module to ``file``."""
        with open(file, 'w') as f:
            f.write(self.target_machine.emit_assembly(self.llvmmod))

    def write_object_file(self, file):
        """Write the object code of the current module to ``file``."""
        with open(file, 'wb') as f:
            f.write(self.target_machine.emit_object(self.llvmmod))

    def optimize(self):
        """Run an O2 pass pipeline (with loop and SLP vectorization) on the current module.

        NOTE(review): parse() already added the module to the engine and finalized it
        (via the llvmmod setter) before compile_llvm calls optimize(). Confirm that these
        passes still influence the machine code the engine generated.
        """
        pmb = llvm.create_pass_manager_builder()
        pmb.opt_level = 2
        pmb.disable_unit_at_a_time = False
        pmb.loop_vectorize = True
        pmb.slp_vectorize = True
        # TODO possible to pass for functions
        pm = llvm.create_module_pass_manager()
        pm.add_instruction_combining_pass()
        pm.add_function_attrs_pass()
        pm.add_constant_merge_pass()
        pm.add_licm_pass()
        pmb.populate(pm)
        pm.run(self.llvmmod)

    def compile(self):
        """Build ``self.fptr`` with a ctypes function for every function defined in ``self.module``."""
        fptr = {}
        for func in self.module.functions:
            if not func.is_declaration:
                return_type = None
                if func.ftype.return_type != ir.VoidType():
                    return_type = to_ctypes(create_composite_type_from_string(str(func.ftype.return_type)))
                args = [ctypes_from_llvm(arg) for arg in func.ftype.args]
                function_address = self.ee.get_function_address(func.name)
                fptr[func.name] = ct.CFUNCTYPE(return_type, *args)(function_address)
        self.fptr = fptr

    def __call__(self, func, *args, **kwargs):
        """Call the compiled function named ``func``.

        Numpy array arguments are passed as pointers of the matching parameter's ctypes
        type. All other arguments are passed unchanged. ``kwargs`` are accepted but unused.

        Raises:
            KeyError: ``func`` is not a function of the parsed module.
        """
        target_function = next((f for f in self.module.functions if f.name == func), None)
        if target_function is None:
            raise KeyError("Function '%s' not found in the compiled module" % func)
        arg_types = [ctypes_from_llvm(arg.type) for arg in target_function.args]

        transformed_args = []
        for i, arg in enumerate(args):
            if isinstance(arg, np.ndarray):
                transformed_args.append(arg.ctypes.data_as(arg_types[i]))
            else:
                transformed_args.append(arg)

        self.fptr[func](*transformed_args)

    def print_functions(self):
        """Print return type, name and arguments of every function in the parsed module."""
        for f in self.module.functions:
            print(f.ftype.return_type, f.name, f.args)

    def get_function_ptr(self, name):
        """Return the ctypes function ``name``; it references this Jit to keep the engine alive."""
        fptr = self.fptr[name]
        fptr.jit = self
        return fptr


# Following code more or less from numba
class CudaJit(Jit):
    """GPU JIT: lowers the module to PTX with the external ``llc`` tool and loads it via pycuda.

    No MCJIT engine is used. ``compile`` writes the IR to the pystencils object cache,
    runs ``llc`` to produce PTX and loads it as a pycuda module. Calls are then launched
    with the block/grid sizes from the AST's indexing.
    """

    CUDA_TRIPLE = {32: 'nvptx-nvidia-cuda',
                   64: 'nvptx64-nvidia-cuda'}
    MACHINE_BITS = tuple.__itemsize__ * 8  # pointer width of the host interpreter
    data_layout = {
        32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
             'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'),
        64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
             'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')}

    default_data_layout = data_layout[MACHINE_BITS]

    # executable used to lower LLVM IR to PTX; override on the class or an instance
    # if the system ships a differently named llc
    llc_executable = 'llc-10'

    def __init__(self, ast):
        # Jit.__init__ is not called: no MCJIT engine or host target machine is used here.
        # Look up the layout in the dict; indexing the default layout *string* would yield one character.
        self._data_layout = self.data_layout[self.MACHINE_BITS]
        self.indexing = ast.indexing

    def optimize(self):
        """Run an O2 pass pipeline (without vectorization) on the parsed module."""
        pmb = llvm.create_pass_manager_builder()
        pmb.opt_level = 2
        pmb.disable_unit_at_a_time = False
        pmb.loop_vectorize = False
        pmb.slp_vectorize = False
        # TODO possible to pass for functions
        pm = llvm.create_module_pass_manager()
        pm.add_instruction_combining_pass()
        pm.add_function_attrs_pass()
        pm.add_constant_merge_pass()
        pm.add_licm_pass()
        pmb.populate(pm)
        pm.run(self.llvmmod)
        # NOTE(review): the pipeline is run a second time as in the original; possibly redundant
        pm.run(self.llvmmod)

    def parse(self, module):
        """Set the NVPTX triple and data layout on ``module`` and store a parsed copy."""
        module.triple = self.CUDA_TRIPLE[self.MACHINE_BITS]
        module.data_layout = self.default_data_layout
        module.name = 'module'
        # Verify the parsed binding-level module. This works whether `module` is an
        # llvmlite.ir or an llvmlite.binding module.
        self._llvmmod = llvm.parse_assembly(str(module))
        self._llvmmod.verify()

    def compile(self):
        """Produce PTX for the parsed module (cached on disk) and load it with pycuda."""
        from pystencils.cpu.cpujit import get_cache_config
        import hashlib

        try:
            from pycuda.driver import Context
            arch = "sm_%d%d" % Context.get_device().compute_capability()
        except Exception:
            # best effort: without a usable device/context fall back to a common architecture
            arch = "sm_35"

        compiler_cache = get_cache_config()['object_cache']
        # the PTX depends on both the IR and the target architecture, so both name the cache files
        base_name = hashlib.md5(str(self._llvmmod).encode()).hexdigest() + '_' + arch
        ir_file = join(compiler_cache, base_name + '.ll')
        ptx_file = join(compiler_cache, base_name + '.ptx')

        if not exists(ptx_file):
            self.write_ll(ir_file)
            subprocess.check_call([self.llc_executable, '-mcpu=' + arch, ir_file, '-o', ptx_file])

        import pycuda.driver
        self.cuda_module = pycuda.driver.module_from_file(ptx_file)

    def __call__(self, func, *args, **kwargs):
        """Launch kernel ``func``; the launch grid is derived from the first argument with a ``shape``."""
        shapes = [a.shape for a in chain(args, kwargs.values()) if hasattr(a, 'shape')]
        if not shapes:
            raise ValueError("Cannot determine launch configuration: no argument has a 'shape'")
        block_and_thread_numbers = self.indexing.call_parameters(shapes[0])
        block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
        block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])
        self.cuda_module.get_function(func)(*args, **kwargs, **block_and_thread_numbers)

    def get_function_ptr(self, name):
        """Return a callable that launches kernel ``name`` via ``__call__``."""
        return partial(self.__call__, name)