llvm.py 15.3 KB
Newer Older
Jan Hoenig's avatar
Jan Hoenig committed
1
import functools
Martin Bauer's avatar
Martin Bauer committed
2

Martin Bauer's avatar
Martin Bauer committed
3
import llvmlite.ir as ir
4
import llvmlite.llvmpy.core as lc
Martin Bauer's avatar
Martin Bauer committed
5
6
7
8
import sympy as sp
from sympy import Indexed, S
from sympy.printing.printer import Printer

Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
10
11
12
from pystencils.data_types import (
    collate_types, create_composite_type_from_string, create_type, get_type_of_expression,
    to_llvm_type)
Jan Hoenig's avatar
Jan Hoenig committed
13
from pystencils.llvm.control_flow import Loop
Jan Hoenig's avatar
Jan Hoenig committed
14
15


16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# From Numba
def set_cuda_kernel(lfunc):
    """Mark *lfunc* as a CUDA kernel entry point using NVVM metadata.

    Adds the operand ``!{lfunc, !"kernel", i32 1}`` to the module's
    ``nvvm.annotations`` named metadata. Once per module, it also adds the
    ``nvvmir.version`` named metadata with the operands ``(1, 2, 2, 0)``.

    Only the plain ``llvmlite.ir`` API is used. This therefore works for any
    ``ir.Module``, not only for ``llvmlite.llvmpy.core.Module``, which
    provided the ``get_or_insert_named_metadata`` helper used previously.

    Args:
        lfunc: the ``ir.Function`` to annotate; its ``module`` is modified in place.
    """
    module = lfunc.module
    i32 = ir.IntType(32)

    # !{<fn>, !"kernel", i32 1} flags <fn> as a kernel for NVVM.
    kernel_md = module.add_metadata([lfunc, ir.MetaDataString(module, "kernel"), i32(1)])
    module.add_named_metadata('nvvm.annotations', kernel_md)

    # add_named_metadata appends to an existing node, so without this guard a
    # second kernel in the same module would duplicate the version entry.
    if 'nvvmir.version' not in module.namedmetadata:
        version_md = module.add_metadata([i32(1), i32(2), i32(2), i32(0)])
        module.add_named_metadata('nvvmir.version', version_md)


34
35
36
37
38
39
40
41
# From Numba
def _call_sreg(builder, name):
    """Emit a call to the special-register intrinsic *name*.

    On first use, the intrinsic is declared in the builder's module as an
    ``i32 ()`` function. Later uses reuse the existing global with that name.
    Only the plain ``llvmlite.ir`` API is used, so any ``ir.Module`` works.

    Args:
        builder: ``ir.IRBuilder`` positioned where the call should be emitted.
        name: full intrinsic name, e.g. ``"llvm.nvvm.read.ptx.sreg.tid.x"``.

    Returns:
        The ``i32`` call instruction.
    """
    module = builder.module
    fn = module.globals.get(name)
    if fn is None:
        fn = ir.Function(module, ir.FunctionType(ir.IntType(32), ()), name=name)
    return builder.call(fn, ())


42
def generate_llvm(ast_node, module=None, builder=None, target='cpu'):
    """Translate a pystencils AST node into LLVM IR.

    Args:
        ast_node: the node to translate, typically a ``KernelFunction``.
        module: module that receives the generated IR. When omitted, a fresh
            ``llvmlite.llvmpy.core.Module`` is created.
        builder: initial ``ir.IRBuilder``. When omitted, a fresh, unpositioned
            builder is created.
        target: forwarded to ``LLVMPrinter``; the value ``'gpu'`` selects
            NVVM types and kernel metadata.

    Returns:
        Whatever the printer produces for *ast_node*. For a kernel function,
        this is the generated ``ir.Function``.
    """
    printer = LLVMPrinter(lc.Module() if module is None else module,
                          ir.IRBuilder() if builder is None else builder,
                          target=target)
    return printer._print(ast_node)
Jan Hoenig's avatar
Jan Hoenig committed
50

51

Martin Bauer's avatar
Martin Bauer committed
52
# noinspection PyPep8Naming
class LLVMPrinter(Printer):
    """Convert pystencils ASTs and sympy expressions to LLVM IR using llvmlite.

    Floating-point values are emitted as ``double`` and integers as ``i64``.
    Every ``_print_<Type>`` method emits instructions through ``self.builder``
    and returns the resulting LLVM value. Statement-like nodes (blocks, loops,
    conditionals) return ``None``.
    """

    def __init__(self, module, builder, fn=None, target='cpu', *args, **kwargs):
        """
        Args:
            module: ``ir.Module`` receiving generated functions and external declarations.
            builder: ``ir.IRBuilder`` used for emission; replaced when a kernel function is printed.
            fn: the function currently being generated, if any.
            target: code-generation target; ``'gpu'`` enables NVVM types and kernel metadata.
            func_arg_map: optional keyword; initial mapping from symbol names to LLVM values.
        """
        self.func_arg_map = kwargs.pop("func_arg_map", {})
        super(LLVMPrinter, self).__init__(*args, **kwargs)
        self.fp_type = ir.DoubleType()
        self.fp_pointer = self.fp_type.as_pointer()
        self.integer = ir.IntType(64)
        self.integer_pointer = self.integer.as_pointer()
        self.void = ir.VoidType()
        self.module = module
        self.builder = builder
        self.fn = fn
        self.ext_fn = {}  # name -> ir.Function: external declarations and generated kernels
        self.tmp_var = {}  # sympy symbol -> LLVM value (e.g. active loop counters)
        self.target = target

    def _add_tmp_var(self, name, value):
        self.tmp_var[name] = value

    def _remove_tmp_var(self, name):
        del self.tmp_var[name]

    def _get_external_function(self, name, arg_count):
        """Return the function *name*, declaring ``double name(double, ...)`` on first use."""
        fn = self.ext_fn.get(name)
        if fn is None:
            fn_type = ir.FunctionType(self.fp_type, [self.fp_type] * arg_count)
            fn = ir.Function(self.module, fn_type, name)
            self.ext_fn[name] = fn
        return fn

    def _element_pointer(self, indexed):
        """Emit a GEP to ``indexed``; only the first index (``args[1]``) is used."""
        ptr = self._print(indexed.base.label)
        index = self._print(indexed.args[1])
        return self.builder.gep(ptr, [index])

    def _print_Number(self, n):
        dtype = get_type_of_expression(n)
        if dtype == create_type("int"):
            return ir.Constant(self.integer, int(n))
        if dtype == create_type("double"):
            return ir.Constant(self.fp_type, float(n))
        raise NotImplementedError("Numbers can only have int and double", n)

    def _print_Float(self, expr):
        return ir.Constant(self.fp_type, float(expr))

    def _print_Integer(self, expr):
        return ir.Constant(self.integer, int(expr))

    def _print_int(self, i):
        return ir.Constant(self.integer, i)

    def _print_Symbol(self, s):
        """Resolve *s*: temporaries (keyed by symbol) first, then arguments/assignments (keyed by name)."""
        val = self.tmp_var.get(s)
        if val is None:
            val = self.func_arg_map.get(s.name)
        if val is None:
            raise LookupError("Symbol not found: %s" % s)
        return val

    def _print_Pow(self, expr):
        """Emit a power using special cases for -1, 1/2, 2 and 3; otherwise call ``pow``."""
        base0 = self._print(expr.base)
        if expr.exp == S.NegativeOne:
            return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
        if expr.exp == S.Half:
            return self.builder.call(self._get_external_function("sqrt", 1), [base0], "sqrt")
        if expr.exp == 2 or expr.exp == 3:
            # Integer bases need `mul`; `fmul` on i64 produces invalid IR.
            mul = self.builder.mul if isinstance(base0.type, ir.IntType) else self.builder.fmul
            square = mul(base0, base0)
            return square if expr.exp == 2 else mul(square, base0)

        exp0 = self._print(expr.exp)
        if isinstance(exp0.type, ir.IntType):
            # pow is declared double(double, double). An integer exponent such
            # as the 4 in x**4 prints as an i64 constant, so convert it first.
            exp0 = self.builder.sitofp(exp0, self.fp_type)
        return self.builder.call(self._get_external_function("pow", 2), [base0, exp0], "pow")

    def _fold_args(self, expr, float_op, int_op):
        """Print all args of *expr* and combine them left to right with the type-appropriate op."""
        nodes = [self._print(a) for a in expr.args]
        # int case: TODO unsigned/signed
        op = float_op if get_type_of_expression(expr) == create_type('double') else int_op
        return functools.reduce(op, nodes)

    def _print_Mul(self, expr):
        return self._fold_args(expr, self.builder.fmul, self.builder.mul)

    def _print_Add(self, expr):
        return self._fold_args(expr, self.builder.fadd, self.builder.add)

    def _print_Or(self, expr):
        return functools.reduce(self.builder.or_, [self._print(a) for a in expr.args])

    def _print_And(self, expr):
        return functools.reduce(self.builder.and_, [self._print(a) for a in expr.args])

    def _print_StrictLessThan(self, expr):
        return self._comparison('<', expr)

    def _print_LessThan(self, expr):
        return self._comparison('<=', expr)

    def _print_StrictGreaterThan(self, expr):
        return self._comparison('>', expr)

    def _print_GreaterThan(self, expr):
        return self._comparison('>=', expr)

    def _print_Unequality(self, expr):
        return self._comparison('!=', expr)

    def _print_Equality(self, expr):
        return self._comparison('==', expr)

    def _comparison(self, cmpop, expr):
        """Emit a float or signed-integer comparison, chosen from the collated operand types."""
        if collate_types([get_type_of_expression(arg) for arg in expr.args]) == create_type('double'):
            # Follow C semantics: with a NaN operand, !=  is true and every
            # other comparison is false, i.e. unordered only for '!='.
            if cmpop == '!=':
                comparison = self.builder.fcmp_unordered
            else:
                comparison = self.builder.fcmp_ordered
        else:
            comparison = self.builder.icmp_signed
        return comparison(cmpop, self._print(expr.lhs), self._print(expr.rhs))

    def _print_KernelFunction(self, func):
        """Generate a void function for *func*; with target 'gpu' also mark it as a kernel."""
        nvvm_target = self.target == 'gpu'
        parameters = func.get_parameters()
        parameter_types = tuple(to_llvm_type(p.symbol.dtype, nvvm_target=nvvm_target) for p in parameters)
        # KernelFunction does not possess a return type
        func_type = ir.FunctionType(self.void, parameter_types)
        name = func.function_name
        fn = ir.Function(self.module, func_type, name)
        self.ext_fn[name] = fn

        # Give the LLVM arguments the parameter names and make them resolvable by symbol name.
        for arg, parameter in zip(fn.args, parameters):
            arg.name = parameter.symbol.name
            self.func_arg_map[parameter.symbol.name] = arg

        block = fn.append_basic_block(name="entry")
        self.builder = ir.IRBuilder(block)  # TODO use goto_block instead
        self._print(func.body)
        self.builder.ret_void()
        self.fn = fn
        if nvvm_target:
            set_cuda_kernel(fn)
        return fn

    def _print_Block(self, block):
        for node in block.args:
            self._print(node)

    def _print_LoopOverCoordinate(self, loop):
        with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
                  loop.loop_counter_name, loop.loop_counter_symbol.name) as i:
            # The loop counter is visible to the body as a temporary keyed by its symbol.
            self._add_tmp_var(loop.loop_counter_symbol, i)
            self._print(loop.body)
            self._remove_tmp_var(loop.loop_counter_symbol)

    def _print_SympyAssignment(self, assignment):
        """Store to memory for an Indexed lhs; otherwise bind the value to the lhs name."""
        expr = self._print(assignment.rhs)
        lhs = assignment.lhs
        if isinstance(lhs, Indexed):
            return self.builder.store(expr, self._element_pointer(lhs))
        self.func_arg_map[lhs.name] = expr
        return expr

    def _print_boolean_cast_func(self, conversion):
        return self._print_cast_func(conversion)

    def _print_cast_func(self, conversion):
        """Emit the LLVM conversion between the argument's type and the cast's target type.

        Raises:
            KeyError: if the (from, to) type pair is not supported.
        """
        node = self._print(conversion.args[0])
        to_dtype = get_type_of_expression(conversion)
        from_dtype = get_type_of_expression(conversion.args[0])
        if from_dtype == to_dtype:
            # No-op: reuse the already emitted value instead of printing it again.
            return node

        t = create_composite_type_from_string
        b = self.builder
        # (from, to) -> (builder method, LLVM target type)
        # TODO float, TEST: const, restrict
        # TODO bitcast, addrspacecast
        # TODO unsigned types (they would need zext / uitofp)
        conversions = {
            # int16/int32 are signed: sign-extend so negative values survive.
            (t("int32"), t("int64")): (b.sext, self.integer),
            (t("int16"), t("int64")): (b.sext, self.integer),
            (t("int"), t("double")): (b.sitofp, self.fp_type),
            (t("int16"), t("double")): (b.sitofp, self.fp_type),
            (t("double"), t("int")): (b.fptosi, self.integer),
            (t("double *"), t("int")): (b.ptrtoint, self.integer),
            (t("int"), t("double *")): (b.inttoptr, self.fp_pointer),
            (t("double * restrict"), t("int")): (b.ptrtoint, self.integer),
            (t("int"), t("double * restrict")): (b.inttoptr, self.fp_pointer),
            (t("double * restrict const"), t("int")): (b.ptrtoint, self.integer),
            (t("int"), t("double * restrict const")): (b.inttoptr, self.fp_pointer),
        }
        convert, llvm_type = conversions[(from_dtype, to_dtype)]
        return convert(node, llvm_type)

    def _print_pointer_arithmetic_func(self, pointer):
        ptr = self._print(pointer.args[0])
        index = self._print(pointer.args[1])
        return self.builder.gep(ptr, [index])

    def _print_Indexed(self, indexed):
        return self.builder.load(self._element_pointer(indexed), name=indexed.base.label.name)

    def _print_Piecewise(self, piece):
        """Lower a Piecewise into a chain of conditional branches joined by a phi node."""
        true = sp.sympify(True)
        # Compare structurally: `not cond` raises TypeError for a Relational.
        if piece.args[-1].cond != true:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        if piece.has(Assignment):
            raise NotImplementedError('The llvm-backend does not support assignments '
                                      'in the Piecewise function. It is questionable '
                                      'whether to implement it. So far there is no '
                                      'use-case to test it.')

        phi_data = []
        after_block = self.builder.append_basic_block()
        for expr, condition in piece.args:
            if condition == true:  # Don't use 'is' use '=='!
                value = self._print(expr)
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(after_block)
            else:
                cond = self._print(condition)
                true_block = self.builder.append_basic_block()
                false_block = self.builder.append_basic_block()
                self.builder.cbranch(cond, true_block, false_block)
                self.builder.position_at_end(true_block)
                value = self._print(expr)
                # Printing expr may open new blocks (e.g. a nested Piecewise),
                # so the incoming edge comes from wherever the builder ended up.
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(false_block)

        phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece), nvvm_target=self.target == 'gpu'))
        for val, block in phi_data:
            phi.add_incoming(val, block)
        return phi

    def _print_Conditional(self, node):
        """Emit an if / if-else statement; nothing is returned."""
        cond = self._print(node.condition_expr)
        if node.false_block is None:
            with self.builder.if_then(cond):
                self._print(node.true_block)
        else:
            with self.builder.if_else(cond) as (then, otherwise):
                with then:
                    self._print(node.true_block)  # emit instructions for when the predicate is true
                with otherwise:
                    self._print(node.false_block)  # emit instructions for when the predicate is false

    def _print_Function(self, expr):
        """Call an external ``double name(double, ...)`` with all printed arguments."""
        name = expr.func.__name__
        args = [self._print(a) for a in expr.args]
        return self.builder.call(self._get_external_function(name, len(args)), args, name)

    def empty_printer(self, expr):
        """Fallback for types without a ``_print_<Type>`` method; always raises TypeError."""
        import inspect
        cls = expr if inspect.isclass(expr) else type(expr)
        mro = inspect.getmro(cls)
        raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
                        % (expr, type(expr), mro))

    # sympy's Printer falls back to `emptyPrinter` (default: `str`). Hook it up
    # so unsupported nodes raise instead of silently turning into strings.
    emptyPrinter = empty_printer

    # from: https://llvm.org/docs/NVPTXUsage.html#nvptx-intrinsics
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'llvm.nvvm.read.ptx.sreg.ctaid',
        'threadIdx': 'llvm.nvvm.read.ptx.sreg.tid',
        'blockDim': 'llvm.nvvm.read.ptx.sreg.ntid',
        'gridDim': 'llvm.nvvm.read.ptx.sreg.nctaid'
    }

    def _print_ThreadIndexingSymbol(self, node):
        """Read the PTX special register for e.g. ``threadIdx.x``, widened to i64."""
        symbol_name: str = node.name
        function_name, dimension = symbol_name.split(".")
        name = f"{self.INDEXING_FUNCTION_MAPPING[function_name]}.{dimension}"
        return self.builder.zext(_call_sreg(self.builder, name), self.integer)