From 4a8659e8f6eb45de7885bf2318078d668312c60d Mon Sep 17 00:00:00 2001 From: Jan Hoenig <hrominium@gmail.com> Date: Fri, 10 Mar 2017 15:06:58 +0100 Subject: [PATCH] it actually somehow comiles --- astnodes.py | 6 ++++ backends/llvm.py | 73 +++++++++++++++++++++++++++++++++++------------- llvm/__init__.py | 1 + llvm/jit.py | 12 ++++---- types.py | 3 ++ 5 files changed, 70 insertions(+), 25 deletions(-) diff --git a/astnodes.py b/astnodes.py index f7a0abac8..a2eaa1e13 100644 --- a/astnodes.py +++ b/astnodes.py @@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr): def __repr__(self): return repr(self.value) + def __float__(self): + return float(self.value) + + def __int__(self): + return int(self.value) + diff --git a/backends/llvm.py b/backends/llvm.py index a70627f02..b97f1c4e5 100644 --- a/backends/llvm.py +++ b/backends/llvm.py @@ -1,10 +1,13 @@ import llvmlite.ir as ir +import functools from sympy.printing.printer import Printer from sympy import S # S is numbers? from pystencils.llvm.control_flow import Loop +from ..types import DataType +from ..astnodes import Indexed def generateLLVM(ast_node): @@ -25,6 +28,7 @@ class LLVMPrinter(Printer): self.fp_type = ir.DoubleType() self.fp_pointer = self.fp_type.as_pointer() self.integer = ir.IntType(64) + self.integer_pointer = self.integer.as_pointer() self.void = ir.VoidType() self.module = module self.builder = builder @@ -35,8 +39,13 @@ class LLVMPrinter(Printer): def _add_tmp_var(self, name, value): self.tmp_var[name] = value - def _print_Number(self, n, **kwargs): - return ir.Constant(self.fp_type, n) + def _print_Number(self, n): + if n.dtype == DataType("int"): + return ir.Constant(self.integer, int(n)) + elif n.dtype == DataType("double"): + return ir.Constant(self.fp_type, float(n)) + else: + raise NotImplementedError("Numbers can only have int and double", n) def _print_Float(self, expr): return ir.Constant(self.fp_type, expr.p) @@ -81,16 +90,23 @@ class LLVMPrinter(Printer): def _print_Mul(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + if expr.dtype == DataType('double'): + mul = self.builder.fmul + else: # int TODO others? + mul = self.builder.mul for node in nodes[1:]: - e = self.builder.fmul(e, node) + e = mul(e, node) return e def _print_Add(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + if expr.dtype == DataType('double'): + add = self.builder.fadd + else: # int TODO others? + add = self.builder.add for node in nodes[1:]: - print(e, node) - e = self.builder.fadd(e, node) + e = add(e, node) return e def _print_KernelFunction(self, function): @@ -118,6 +134,7 @@ class LLVMPrinter(Printer): block = fn.append_basic_block(name="entry") self.builder = ir.IRBuilder(block) self._print(function.body) + self.builder.ret_void() self.fn = fn return fn @@ -129,29 +146,47 @@ class LLVMPrinter(Printer): with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), loop.loopCounterName, loop.loopCounterSymbol.name) as i: self._add_tmp_var(loop.loopCounterSymbol, i) + # TODO remove tmp var self._print(loop.body) def _print_SympyAssignment(self, assignment): expr = self._print(assignment.rhs) + lhs = assignment.lhs + if isinstance(lhs, Indexed): + ptr = self._print(lhs.base.label) + index = self._print(lhs.args[1]) + gep = self.builder.gep(ptr, [index]) + return self.builder.store(expr, gep) + self.func_arg_map[assignment.lhs.name] = expr + return expr def _print_Conversion(self, conversion): + node = self._print(conversion.args[0]) to_dtype = conversion.dtype from_dtype = conversion.args[0].dtype - print(to_dtype, from_dtype) - # fp -> int: fptosi - # int -> fp: sitofp - # ptr -> int: ptrtoint - # int -> ptr: inttoptr - # ?bitcast, ?addrspacecast + # (From, to) + decision = { + (DataType("int"), DataType("double")): functools.partial(self.builder.sitofp, node, self.fp_type), + (DataType("double"), DataType("int")): functools.partial(self.builder.fptosi, node, self.integer), + (DataType("double *"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (DataType("int"), DataType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + (DataType("double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (DataType("int"), DataType("double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + (DataType("const double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (DataType("int"), DataType("const double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + } + # TODO float, const, restrict + # TODO bitcast, addrspacecast + return decision[(from_dtype, to_dtype)]() def _print_Indexed(self, indexed): - pass + ptr = self._print(indexed.base.label) + index = self._print(indexed.args[1]) + gep = self.builder.gep(ptr, [index]) + return self.builder.load(gep, name=indexed.base.label.name) - - - # Should have a list of math library functions to validate this. - - # TODO delete this -> NO this should be a function call + # Should have a list of math library functions to validate this. + # TODO function calls def _print_Function(self, expr): name = expr.func.__name__ e0 = self._print(expr.args[0]) @@ -163,5 +198,5 @@ class LLVMPrinter(Printer): return self.builder.call(fn, [e0], name) def emptyPrinter(self, expr): - raise TypeError("Unsupported type for LLVM JIT conversion: %s" - % type(expr)) + raise TypeError("Unsupported type for LLVM JIT conversion: %s %s" + % type(expr), expr) diff --git a/llvm/__init__.py b/llvm/__init__.py index da5dfa39d..f34532f8b 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1 +1,2 @@ from .kernelcreation import createKernel +from .jit import compileLLVM \ No newline at end of file diff --git a/llvm/jit.py b/llvm/jit.py index 8e6fdb56f..2b13d7e7b 100644 --- a/llvm/jit.py +++ b/llvm/jit.py @@ -1,6 +1,12 @@ import llvmlite.binding as llvm import logging.config +logger = logging.getLogger(__name__) + + +def compileLLVM(module): + return Eval().compile(module) + class Eval(object): def __init__(self): @@ -63,9 +69,3 @@ class Eval(object): # result = fptr(2, 3) # print(result) return 0 - - -if __name__ == "__main__": - logger = logging.getLogger(__name__) -else: - logger = logging.getLogger(__name__) diff --git a/types.py b/types.py index 3550de398..0fd58daa7 100644 --- a/types.py +++ b/types.py @@ -70,6 +70,9 @@ class DataType(object): if self.dtype > other.dtype: return True + def __hash__(self): + return hash(repr(self)) + def get_type_from_sympy(node): # Rational, NumberSymbol? -- GitLab