From 4a8659e8f6eb45de7885bf2318078d668312c60d Mon Sep 17 00:00:00 2001
From: Jan Hoenig <hrominium@gmail.com>
Date: Fri, 10 Mar 2017 15:06:58 +0100
Subject: [PATCH] it actually somehow comiles

---
 astnodes.py      |  6 ++++
 backends/llvm.py | 73 +++++++++++++++++++++++++++++++++++-------------
 llvm/__init__.py |  1 +
 llvm/jit.py      | 12 ++++----
 types.py         |  3 ++
 5 files changed, 70 insertions(+), 25 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index f7a0abac8..a2eaa1e13 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr):
     def __repr__(self):
         return repr(self.value)
 
+    def __float__(self):
+        return float(self.value)
+
+    def __int__(self):
+        return int(self.value)
+
 
diff --git a/backends/llvm.py b/backends/llvm.py
index a70627f02..b97f1c4e5 100644
--- a/backends/llvm.py
+++ b/backends/llvm.py
@@ -1,10 +1,13 @@
 import llvmlite.ir as ir
+import functools
 
 from sympy.printing.printer import Printer
 from sympy import S
 # S is numbers?
 
 from pystencils.llvm.control_flow import Loop
+from ..types import DataType
+from ..astnodes import Indexed
 
 
 def generateLLVM(ast_node):
@@ -25,6 +28,7 @@ class LLVMPrinter(Printer):
         self.fp_type = ir.DoubleType()
         self.fp_pointer = self.fp_type.as_pointer()
         self.integer = ir.IntType(64)
+        self.integer_pointer = self.integer.as_pointer()
         self.void = ir.VoidType()
         self.module = module
         self.builder = builder
@@ -35,8 +39,13 @@ class LLVMPrinter(Printer):
     def _add_tmp_var(self, name, value):
         self.tmp_var[name] = value
 
-    def _print_Number(self, n, **kwargs):
-        return ir.Constant(self.fp_type, n)
+    def _print_Number(self, n):
+        if n.dtype == DataType("int"):
+            return ir.Constant(self.integer, int(n))
+        elif n.dtype == DataType("double"):
+            return ir.Constant(self.fp_type, float(n))
+        else:
+            raise NotImplementedError("Numbers can only have int and double", n)
 
     def _print_Float(self, expr):
         return ir.Constant(self.fp_type, expr.p)
@@ -81,16 +90,23 @@ class LLVMPrinter(Printer):
     def _print_Mul(self, expr):
         nodes = [self._print(a) for a in expr.args]
         e = nodes[0]
+        if expr.dtype == DataType('double'):
+            mul = self.builder.fmul
+        else: # int TODO others?
+            mul = self.builder.mul
         for node in nodes[1:]:
-            e = self.builder.fmul(e, node)
+            e = mul(e, node)
         return e
 
     def _print_Add(self, expr):
         nodes = [self._print(a) for a in expr.args]
         e = nodes[0]
+        if expr.dtype == DataType('double'):
+            add = self.builder.fadd
+        else: # int TODO others?
+            add = self.builder.add
         for node in nodes[1:]:
-            print(e, node)
-            e = self.builder.fadd(e, node)
+            e = add(e, node)
         return e
 
     def _print_KernelFunction(self, function):
@@ -118,6 +134,7 @@ class LLVMPrinter(Printer):
         block = fn.append_basic_block(name="entry")
         self.builder = ir.IRBuilder(block)
         self._print(function.body)
+        self.builder.ret_void()
         self.fn = fn
         return fn
 
@@ -129,29 +146,47 @@ class LLVMPrinter(Printer):
         with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
                   loop.loopCounterName, loop.loopCounterSymbol.name) as i:
             self._add_tmp_var(loop.loopCounterSymbol, i)
+            # TODO remove tmp var
             self._print(loop.body)
 
     def _print_SympyAssignment(self, assignment):
         expr = self._print(assignment.rhs)
+        lhs = assignment.lhs
+        if isinstance(lhs, Indexed):
+            ptr = self._print(lhs.base.label)
+            index = self._print(lhs.args[1])
+            gep = self.builder.gep(ptr, [index])
+            return self.builder.store(expr, gep)
+        self.func_arg_map[assignment.lhs.name] = expr
+        return expr
 
     def _print_Conversion(self, conversion):
+        node = self._print(conversion.args[0])
         to_dtype = conversion.dtype
         from_dtype = conversion.args[0].dtype
-        print(to_dtype, from_dtype)
-        # fp -> int: fptosi
-        # int -> fp: sitofp
-        # ptr -> int: ptrtoint
-        # int -> ptr: inttoptr
-        # ?bitcast, ?addrspacecast
+        # (From, to)
+        decision = {
+                    (DataType("int"), DataType("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
+                    (DataType("double"), DataType("int")): functools.partial(self.builder.fptosi, node, self.integer),
+                    (DataType("double *"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
+                    (DataType("int"), DataType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
+                    (DataType("double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
+                    (DataType("int"), DataType("double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
+                    (DataType("const double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
+                    (DataType("int"), DataType("const double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
+                    }
+        # TODO float, const, restrict
+        # TODO bitcast, addrspacecast
+        return decision[(from_dtype, to_dtype)]()
 
     def _print_Indexed(self, indexed):
-        pass
+        ptr = self._print(indexed.base.label)
+        index = self._print(indexed.args[1])
+        gep = self.builder.gep(ptr, [index])
+        return self.builder.load(gep, name=indexed.base.label.name)
 
-
-
-        #  Should have a list of math library functions to validate this.
-
-    # TODO delete this -> NO this should be a function call
+    # Should have a list of math library functions to validate this.
+    # TODO function calls
     def _print_Function(self, expr):
         name = expr.func.__name__
         e0 = self._print(expr.args[0])
@@ -163,5 +198,5 @@ class LLVMPrinter(Printer):
         return self.builder.call(fn, [e0], name)
 
     def emptyPrinter(self, expr):
-        raise TypeError("Unsupported type for LLVM JIT conversion: %s"
-                        % type(expr))
+        raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
+                        % type(expr), expr)
diff --git a/llvm/__init__.py b/llvm/__init__.py
index da5dfa39d..f34532f8b 100644
--- a/llvm/__init__.py
+++ b/llvm/__init__.py
@@ -1 +1,2 @@
 from .kernelcreation import createKernel
+from .jit import compileLLVM
\ No newline at end of file
diff --git a/llvm/jit.py b/llvm/jit.py
index 8e6fdb56f..2b13d7e7b 100644
--- a/llvm/jit.py
+++ b/llvm/jit.py
@@ -1,6 +1,12 @@
 import llvmlite.binding as llvm
 import logging.config
 
+logger = logging.getLogger(__name__)
+
+
+def compileLLVM(module):
+    return Eval().compile(module)
+
 
 class Eval(object):
     def __init__(self):
@@ -63,9 +69,3 @@ class Eval(object):
             # result = fptr(2, 3)
             # print(result)
             return 0
-
-
-if __name__ == "__main__":
-    logger = logging.getLogger(__name__)
-else:
-    logger = logging.getLogger(__name__)
diff --git a/types.py b/types.py
index 3550de398..0fd58daa7 100644
--- a/types.py
+++ b/types.py
@@ -70,6 +70,9 @@ class DataType(object):
         if self.dtype > other.dtype:
             return True
 
+    def __hash__(self):
+        return hash(repr(self))
+
 
 def get_type_from_sympy(node):
     # Rational, NumberSymbol?
-- 
GitLab