diff --git a/astnodes.py b/astnodes.py index f7a0abac8e0340cf80cab6e725f7734fdf3f2f2a..a2eaa1e1375b5cc73bc02ff4140d818754db2e16 100644 --- a/astnodes.py +++ b/astnodes.py @@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr): def __repr__(self): return repr(self.value) + def __float__(self): + return float(self.value) + + def __int__(self): + return int(self.value) + diff --git a/backends/llvm.py b/backends/llvm.py index a70627f0249486bac1f34c63b94c853f07055346..b97f1c4e5c5ccc02dda07e77e1c93488c0f5076d 100644 --- a/backends/llvm.py +++ b/backends/llvm.py @@ -1,10 +1,13 @@ import llvmlite.ir as ir +import functools from sympy.printing.printer import Printer from sympy import S # S is numbers? from pystencils.llvm.control_flow import Loop +from ..types import DataType +from ..astnodes import Indexed def generateLLVM(ast_node): @@ -25,6 +28,7 @@ class LLVMPrinter(Printer): self.fp_type = ir.DoubleType() self.fp_pointer = self.fp_type.as_pointer() self.integer = ir.IntType(64) + self.integer_pointer = self.integer.as_pointer() self.void = ir.VoidType() self.module = module self.builder = builder @@ -35,8 +39,13 @@ class LLVMPrinter(Printer): def _add_tmp_var(self, name, value): self.tmp_var[name] = value - def _print_Number(self, n, **kwargs): - return ir.Constant(self.fp_type, n) + def _print_Number(self, n): + if n.dtype == DataType("int"): + return ir.Constant(self.integer, int(n)) + elif n.dtype == DataType("double"): + return ir.Constant(self.fp_type, float(n)) + else: + raise NotImplementedError("Numbers can only have int and double", n) def _print_Float(self, expr): return ir.Constant(self.fp_type, expr.p) @@ -81,16 +90,23 @@ class LLVMPrinter(Printer): def _print_Mul(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + if expr.dtype == DataType('double'): + mul = self.builder.fmul + else: # int TODO others? + mul = self.builder.mul for node in nodes[1:]: - e = self.builder.fmul(e, node) + e = mul(e, node) return e def _print_Add(self, expr): nodes = [self._print(a) for a in expr.args] e = nodes[0] + if expr.dtype == DataType('double'): + add = self.builder.fadd + else: # int TODO others? + add = self.builder.add for node in nodes[1:]: - print(e, node) - e = self.builder.fadd(e, node) + e = add(e, node) return e def _print_KernelFunction(self, function): @@ -118,6 +134,7 @@ class LLVMPrinter(Printer): block = fn.append_basic_block(name="entry") self.builder = ir.IRBuilder(block) self._print(function.body) + self.builder.ret_void() self.fn = fn return fn @@ -129,29 +146,47 @@ class LLVMPrinter(Printer): with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), loop.loopCounterName, loop.loopCounterSymbol.name) as i: self._add_tmp_var(loop.loopCounterSymbol, i) + # TODO remove tmp var self._print(loop.body) def _print_SympyAssignment(self, assignment): expr = self._print(assignment.rhs) + lhs = assignment.lhs + if isinstance(lhs, Indexed): + ptr = self._print(lhs.base.label) + index = self._print(lhs.args[1]) + gep = self.builder.gep(ptr, [index]) + return self.builder.store(expr, gep) + self.func_arg_map[assignment.lhs.name] = expr + return expr def _print_Conversion(self, conversion): + node = self._print(conversion.args[0]) to_dtype = conversion.dtype from_dtype = conversion.args[0].dtype - print(to_dtype, from_dtype) - # fp -> int: fptosi - # int -> fp: sitofp - # ptr -> int: ptrtoint - # int -> ptr: inttoptr - # ?bitcast, ?addrspacecast + # (From, to) + decision = { + (DataType("int"), DataType("double")): functools.partial(self.builder.sitofp, node, self.fp_type), + (DataType("double"), DataType("int")): functools.partial(self.builder.fptosi, node, self.integer), + (DataType("double *"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (DataType("int"), DataType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + (DataType("double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (DataType("int"), DataType("double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + (DataType("const double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer), + (DataType("int"), DataType("const double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer), + } + # TODO float, const, restrict + # TODO bitcast, addrspacecast + return decision[(from_dtype, to_dtype)]() def _print_Indexed(self, indexed): - pass + ptr = self._print(indexed.base.label) + index = self._print(indexed.args[1]) + gep = self.builder.gep(ptr, [index]) + return self.builder.load(gep, name=indexed.base.label.name) - - - # Should have a list of math library functions to validate this. - - # TODO delete this -> NO this should be a function call + # Should have a list of math library functions to validate this. + # TODO function calls def _print_Function(self, expr): name = expr.func.__name__ e0 = self._print(expr.args[0]) @@ -163,5 +198,5 @@ class LLVMPrinter(Printer): return self.builder.call(fn, [e0], name) def emptyPrinter(self, expr): - raise TypeError("Unsupported type for LLVM JIT conversion: %s" - % type(expr)) + raise TypeError("Unsupported type for LLVM JIT conversion: %s %s" + % type(expr), expr) diff --git a/llvm/__init__.py b/llvm/__init__.py index da5dfa39db26286f274c84e10aa38dd635b75465..f34532f8bea8ee95dd295cde4f521fb8902cb3d0 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1 +1,2 @@ from .kernelcreation import createKernel +from .jit import compileLLVM \ No newline at end of file diff --git a/llvm/jit.py b/llvm/jit.py index 8e6fdb56f947ce6a98ba93aad5f8bb2ef9006014..2b13d7e7bcfbbb978502ec54fba089d02fdf8103 100644 --- a/llvm/jit.py +++ b/llvm/jit.py @@ -1,6 +1,12 @@ import llvmlite.binding as llvm import logging.config +logger = logging.getLogger(__name__) + + +def compileLLVM(module): + return Eval().compile(module) + class Eval(object): def __init__(self): @@ -63,9 +69,3 @@ class Eval(object): # result = fptr(2, 3) # print(result) return 0 - - -if __name__ == "__main__": - logger = logging.getLogger(__name__) -else: - logger = logging.getLogger(__name__) diff --git a/types.py b/types.py index 3550de398ffe3ecbd287c13e838d65c2c8f0143a..0fd58daa79aca1d838cf08b49a92e95545f8cc84 100644 --- a/types.py +++ b/types.py @@ -70,6 +70,9 @@ class DataType(object): if self.dtype > other.dtype: return True + def __hash__(self): + return hash(repr(self)) + def get_type_from_sympy(node): # Rational, NumberSymbol?