llvm.py 13.7 KB
Newer Older
Jan Hoenig's avatar
Jan Hoenig committed
1
import functools
Martin Bauer's avatar
Martin Bauer committed
2

Martin Bauer's avatar
Martin Bauer committed
3
import llvmlite.ir as ir
Martin Bauer's avatar
Martin Bauer committed
4
5
6
7
import sympy as sp
from sympy import Indexed, S
from sympy.printing.printer import Printer

Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
9
10
11
from pystencils.data_types import (
    collate_types, create_composite_type_from_string, create_type, get_type_of_expression,
    to_llvm_type)
Jan Hoenig's avatar
Jan Hoenig committed
12
from pystencils.llvm.control_flow import Loop
Jan Hoenig's avatar
Jan Hoenig committed
13
14


Martin Bauer's avatar
Martin Bauer committed
15
16
def generate_llvm(ast_node, module=None, builder=None):
    """Prints the ast as llvm code.

    Args:
        ast_node: node (e.g. a KernelFunction) to translate to LLVM IR
        module: llvmlite ``ir.Module`` receiving the generated functions; a new, empty module
                is created when omitted
        builder: llvmlite ``ir.IRBuilder`` used to emit instructions; a fresh, unpositioned
                 builder is created when omitted

    Returns:
        whatever the printer produces for ``ast_node`` - for a KernelFunction this is the
        generated ``ir.Function``
    """
    target_module = ir.Module() if module is None else module
    ir_builder = ir.IRBuilder() if builder is None else builder
    return LLVMPrinter(target_module, ir_builder)._print(ast_node)
Jan Hoenig's avatar
Jan Hoenig committed
23

24

Martin Bauer's avatar
Martin Bauer committed
25
# noinspection PyPep8Naming
26
27
class LLVMPrinter(Printer):
    """Convert expressions and pystencils AST nodes to LLVM IR.

    Instructions are emitted through an llvmlite ``ir.IRBuilder`` into an llvmlite ``ir.Module``.
    Dispatch follows sympy's ``Printer``: ``_print(obj)`` calls the ``_print_<ClassName>`` method of
    the first class in ``type(obj).__mro__`` that defines one and otherwise falls back to
    ``emptyPrinter``, which raises ``TypeError`` here.

    Floating point constants are emitted as ``double`` and integer constants as ``i64``.
    """

    def __init__(self, module, builder, fn=None, *args, **kwargs):
        """
        Args:
            module: llvmlite ``ir.Module`` that receives all declared and generated functions
            builder: llvmlite ``ir.IRBuilder`` used to emit instructions; it is replaced by a builder
                     positioned in the entry block when a KernelFunction is printed
            fn: the function currently being generated, if any
            func_arg_map: (keyword) initial mapping from symbol name to LLVM value
            *args, **kwargs: forwarded to sympy's ``Printer``
        """
        self.func_arg_map = kwargs.pop("func_arg_map", {})
        super(LLVMPrinter, self).__init__(*args, **kwargs)
        self.fp_type = ir.DoubleType()
        self.fp_pointer = self.fp_type.as_pointer()
        self.integer = ir.IntType(64)
        self.integer_pointer = self.integer.as_pointer()
        self.void = ir.VoidType()
        self.module = module
        self.builder = builder
        self.fn = fn
        self.ext_fn = {}  # name -> ir.Function: declared external functions and generated kernels
        self.tmp_var = {}  # symbol -> LLVM value of the currently active loop counters

    def _add_tmp_var(self, name, value):
        self.tmp_var[name] = value

    def _remove_tmp_var(self, name):
        del self.tmp_var[name]

    # ------------------------------------------------------------------ helpers

    def _to_fp(self, value):
        """Return ``value`` converted to double if it is an LLVM integer, otherwise unchanged."""
        if isinstance(value.type, ir.IntType):
            return self.builder.sitofp(value, self.fp_type)
        return value

    def _get_external_fp_function(self, name, arity=1):
        """Return (declaring it on first use) an external function ``double name(double, ...)``."""
        fn = self.ext_fn.get(name)
        if fn is None:
            fn_type = ir.FunctionType(self.fp_type, [self.fp_type] * arity)
            fn = ir.Function(self.module, fn_type, name)
            self.ext_fn[name] = fn
        return fn

    def _reduce_arithmetic(self, expr, fp_op, int_op):
        """Left-fold all printed arguments of ``expr`` with ``fp_op`` (double) or ``int_op`` (int)."""
        operands = [self._print(a) for a in expr.args]
        if get_type_of_expression(expr) == create_type('double'):
            # integer operands (e.g. integer literals) must be promoted, otherwise the
            # floating point instruction receives mismatching operand types
            return functools.reduce(fp_op, [self._to_fp(o) for o in operands])
        # int TODO unsigned/signed
        return functools.reduce(int_op, operands)

    # ------------------------------------------------------------------ atoms

    def _print_Number(self, n):
        dtype = get_type_of_expression(n)
        if dtype == create_type("int"):
            return ir.Constant(self.integer, int(n))
        if dtype == create_type("double"):
            return ir.Constant(self.fp_type, float(n))
        raise NotImplementedError("Numbers can only have int and double", n)

    def _print_Float(self, expr):
        return ir.Constant(self.fp_type, float(expr))

    def _print_Integer(self, expr):
        return ir.Constant(self.integer, int(expr))

    def _print_int(self, i):
        return ir.Constant(self.integer, i)

    def _print_Symbol(self, s):
        """Resolve a symbol: loop counters first, then kernel parameters / assigned symbols by name.

        Raises:
            LookupError: if the symbol is unknown.
        """
        val = self.tmp_var.get(s)
        if val is None:
            val = self.func_arg_map.get(s.name)
        if val is None:
            raise LookupError("Symbol not found: %s" % s)
        return val

    # ------------------------------------------------------------------ arithmetic

    def _print_Pow(self, expr):
        """Emit a power: special-cases exponents -1, 1/2, 2 and 3, otherwise calls external ``pow``."""
        base0 = self._print(expr.base)
        if expr.exp == S.NegativeOne:
            return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), self._to_fp(base0))
        if expr.exp == S.Half:
            return self.builder.call(self._get_external_fp_function("sqrt", 1), [self._to_fp(base0)], "sqrt")
        if expr.exp == 2 or expr.exp == 3:
            if expr.exp.is_Integer and isinstance(base0.type, ir.IntType):
                mul = self.builder.mul  # integer base and integer exponent stay integer
            else:
                base0 = self._to_fp(base0)
                mul = self.builder.fmul
            result = mul(base0, base0)
            if expr.exp == 3:
                result = mul(result, base0)
            return result

        # ``pow`` is declared as double(double, double): numeric exponents are emitted directly as
        # double constants, symbolic integer exponents are converted
        if expr.exp.is_Number:
            exp0 = ir.Constant(self.fp_type, float(expr.exp))
        else:
            exp0 = self._to_fp(self._print(expr.exp))
        return self.builder.call(self._get_external_fp_function("pow", 2), [self._to_fp(base0), exp0], "pow")

    def _print_Mul(self, expr):
        return self._reduce_arithmetic(expr, self.builder.fmul, self.builder.mul)

    def _print_Add(self, expr):
        return self._reduce_arithmetic(expr, self.builder.fadd, self.builder.add)

    # ------------------------------------------------------------------ logic / comparisons

    def _print_Or(self, expr):
        return functools.reduce(self.builder.or_, [self._print(a) for a in expr.args])

    def _print_And(self, expr):
        return functools.reduce(self.builder.and_, [self._print(a) for a in expr.args])

    def _print_StrictLessThan(self, expr):
        return self._comparison('<', expr)

    def _print_LessThan(self, expr):
        return self._comparison('<=', expr)

    def _print_StrictGreaterThan(self, expr):
        return self._comparison('>', expr)

    def _print_GreaterThan(self, expr):
        return self._comparison('>=', expr)

    def _print_Unequality(self, expr):
        return self._comparison('!=', expr)

    def _print_Equality(self, expr):
        return self._comparison('==', expr)

    def _comparison(self, cmpop, expr):
        """Emit ``lhs <cmpop> rhs`` as a floating point or signed integer comparison.

        Floating point comparisons follow C semantics: ``!=`` is true if an operand is NaN
        (unordered), every other comparison is false in that case (ordered).
        """
        if collate_types([get_type_of_expression(arg) for arg in expr.args]) == create_type('double'):
            lhs = self._to_fp(self._print(expr.lhs))
            rhs = self._to_fp(self._print(expr.rhs))
            if cmpop == '!=':
                return self.builder.fcmp_unordered(cmpop, lhs, rhs)
            return self.builder.fcmp_ordered(cmpop, lhs, rhs)
        return self.builder.icmp_signed(cmpop, self._print(expr.lhs), self._print(expr.rhs))

    # ------------------------------------------------------------------ AST nodes

    def _print_KernelFunction(self, func):
        """Define a void LLVM function for ``func`` and emit its body; returns the ``ir.Function``."""
        # KernelFunction does not possess a return type
        parameters = func.get_parameters()
        parameter_types = tuple(to_llvm_type(parameter.symbol.dtype) for parameter in parameters)
        func_type = ir.FunctionType(self.void, parameter_types)
        name = func.function_name
        fn = ir.Function(self.module, func_type, name)
        self.ext_fn[name] = fn

        # give the LLVM arguments the parameter names and make them resolvable as symbols
        for arg, parameter in zip(fn.args, parameters):
            arg.name = parameter.symbol.name
            self.func_arg_map[parameter.symbol.name] = arg

        block = fn.append_basic_block(name="entry")
        self.builder = ir.IRBuilder(block)  # TODO use goto_block instead
        self._print(func.body)
        self.builder.ret_void()
        self.fn = fn
        return fn

    def _print_Block(self, block):
        for node in block.args:
            self._print(node)

    def _print_LoopOverCoordinate(self, loop):
        with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
                  loop.loop_counter_name, loop.loop_counter_symbol.name) as i:
            self._add_tmp_var(loop.loop_counter_symbol, i)
            try:
                self._print(loop.body)
            finally:
                # never leave a stale loop counter behind, even if printing the body fails
                self._remove_tmp_var(loop.loop_counter_symbol)

    def _print_SympyAssignment(self, assignment):
        """Store to memory for an ``Indexed`` lhs, otherwise bind the lhs name to the value."""
        expr = self._print(assignment.rhs)
        lhs = assignment.lhs
        if isinstance(lhs, Indexed):
            ptr = self._print(lhs.base.label)
            index = self._print(lhs.args[1])
            gep = self.builder.gep(ptr, [index])
            return self.builder.store(expr, gep)
        self.func_arg_map[lhs.name] = expr
        return expr

    def _print_boolean_cast_func(self, conversion):
        return self._print_cast_func(conversion)

    def _print_cast_func(self, conversion):
        """Emit the LLVM cast instruction converting ``conversion.args[0]`` to the cast's type.

        Raises:
            KeyError: if the (from, to) type combination is not supported.
        """
        node = self._print(conversion.args[0])
        to_dtype = get_type_of_expression(conversion)
        from_dtype = get_type_of_expression(conversion.args[0])
        if from_dtype == to_dtype:
            # reuse the already printed value; printing again would emit its instructions twice
            return node

        builder = self.builder
        ptr_to_int = functools.partial(builder.ptrtoint, node, self.integer)
        int_to_fp_ptr = functools.partial(builder.inttoptr, node, self.fp_pointer)
        # (from, to) -> emitter. int16/int32 are signed, hence sign extension.
        conversions = {
            ("int32", "int64"): functools.partial(builder.sext, node, self.integer),
            ("int16", "int64"): functools.partial(builder.sext, node, self.integer),
            ("int", "double"): functools.partial(builder.sitofp, node, self.fp_type),
            ("int16", "double"): functools.partial(builder.sitofp, node, self.fp_type),
            ("double", "int"): functools.partial(builder.fptosi, node, self.integer),
            ("double *", "int"): ptr_to_int,
            ("int", "double *"): int_to_fp_ptr,
            ("double * restrict", "int"): ptr_to_int,
            ("int", "double * restrict"): int_to_fp_ptr,
            ("double * restrict const", "int"): ptr_to_int,
            ("int", "double * restrict const"): int_to_fp_ptr,
        }
        decision = {(create_composite_type_from_string(src), create_composite_type_from_string(dst)): emit
                    for (src, dst), emit in conversions.items()}
        # TODO float, TEST: const, restrict
        # TODO bitcast, addrspacecast
        # TODO unsigned types
        try:
            emit = decision[(from_dtype, to_dtype)]
        except KeyError:
            raise KeyError("Unsupported cast from %s to %s" % (from_dtype, to_dtype)) from None
        return emit()

    def _print_pointer_arithmetic_func(self, pointer):
        ptr = self._print(pointer.args[0])
        index = self._print(pointer.args[1])
        return self.builder.gep(ptr, [index])

    def _print_Indexed(self, indexed):
        """Load ``base[index]``; only ``indexed.args[1]`` is used (assumes a single, linear index)."""
        ptr = self._print(indexed.base.label)
        index = self._print(indexed.args[1])
        gep = self.builder.gep(ptr, [index])
        return self.builder.load(gep, name=indexed.base.label.name)

    def _print_Piecewise(self, piece):
        """Emit a Piecewise as a chain of conditional branches joined by a phi node.

        Raises:
            ValueError: if the last condition is not ``True`` (no default value).
            NotImplementedError: if the Piecewise contains assignments.
        """
        if piece.args[-1].cond != S.true:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        if piece.has(Assignment):
            raise NotImplementedError('The llvm-backend does not support assignments '
                                      'in the Piecewise function. It is questionable '
                                      'whether to implement it. So far there is no '
                                      'use-case to test it.')

        phi_data = []
        after_block = self.builder.append_basic_block()
        for expr, condition in piece.args:
            if condition == S.true:  # Don't use 'is' use '=='!
                value = self._print(expr)
                # printing may open new blocks (e.g. nested Piecewise): the phi's incoming
                # block is the one the builder ended up in
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(after_block)
            else:
                cond = self._print(condition)
                true_block = self.builder.append_basic_block()
                false_block = self.builder.append_basic_block()
                self.builder.cbranch(cond, true_block, false_block)
                self.builder.position_at_end(true_block)
                value = self._print(expr)
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(false_block)

        phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece)))
        for value, block in phi_data:
            phi.add_incoming(value, block)
        return phi

    def _print_Conditional(self, node):
        """Emit an if/else statement; produces no value."""
        cond = self._print(node.condition_expr)
        with self.builder.if_else(cond) as (then, otherwise):
            with then:
                self._print(node.true_block)  # emit instructions for when the predicate is true
            with otherwise:
                if node.false_block is not None:  # conditionals without else branch
                    self._print(node.false_block)  # emit instructions for when the predicate is false

    def _print_Function(self, expr):
        """Call an external ``double name(double, ...)`` function named after the sympy function."""
        name = expr.func.__name__
        args = [self._to_fp(self._print(a)) for a in expr.args]
        fn = self._get_external_fp_function(name, len(args))
        return self.builder.call(fn, args, name)

    def empty_printer(self, expr):
        """Fallback for objects without a ``_print_<Class>`` method: always raises ``TypeError``."""
        raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
                        % (expr, type(expr), type(expr).__mro__))

    # sympy's Printer falls back to ``emptyPrinter`` (default: ``str``); without this alias
    # unsupported nodes would silently be turned into strings instead of raising
    emptyPrinter = empty_printer