From 231fb6aff5f29618347d3bbaee0858e1b72ae7fb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Mon, 16 Oct 2017 17:06:31 +0200
Subject: [PATCH] Added demo for spinodal decomposition, which serves as a
 test. Fixed a severe bug. Renamed makePythonFunction of the llvm backend.
 Deleted code duplicity.

---
 cpu/cpujit.py             | 10 ++++------
 llvm/__init__.py          |  2 +-
 llvm/llvm.py              |  3 ++-
 llvm/llvmjit.py           | 35 ++++++-----------------------------
 transformations/stage2.py |  4 ++--
 5 files changed, 15 insertions(+), 39 deletions(-)

diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index a64f04529..3e1db4520 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -90,13 +90,13 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
     :return: kernel functor
     """
     # build up list of CType arguments
+    func = compileAndLoad(kernelFunctionNode)
+    func.restype = None
     try:
         args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict)
     except KeyError:
         # not all parameters specified yet
-        return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict)
-    func = compileAndLoad(kernelFunctionNode)
-    func.restype = None
+        return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func)
     return lambda: func(*args)
 
 
@@ -427,9 +427,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
     return ctArguments
 
 
-def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict):
-    func = compileAndLoad(kernelFunctionNode)
-    func.restype = None
+def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func):
     parameters = kernelFunctionNode.parameters
 
     cache = {}
diff --git a/llvm/__init__.py b/llvm/__init__.py
index 16cd3d751..77fdd2da6 100644
--- a/llvm/__init__.py
+++ b/llvm/__init__.py
@@ -1,3 +1,3 @@
 from .kernelcreation import createKernel, createIndexedKernel
-from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function
+from .llvmjit import compileLLVM, generate_and_jit, Jit, makePythonFunction
 from .llvm import generateLLVM
diff --git a/llvm/llvm.py b/llvm/llvm.py
index b48ec47bb..0eac24b93 100644
--- a/llvm/llvm.py
+++ b/llvm/llvm.py
@@ -71,7 +71,6 @@ class LLVMPrinter(Printer):
         return val
 
     def _print_Pow(self, expr):
-        #print(expr)
         base0 = self._print(expr.base)
         if expr.exp == S.NegativeOne:
             return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
@@ -84,6 +83,8 @@ class LLVMPrinter(Printer):
             return self.builder.call(fn, [base0], "sqrt")
         if expr.exp == 2:
             return self.builder.fmul(base0, base0)
+        elif expr.exp == 3:
+            return self.builder.fmul(self.builder.fmul(base0, base0), base0)
 
         exp0 = self._print(expr.exp)
         fn = self.ext_fn.get("pow")
diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py
index d767c0ee5..710439e8b 100644
--- a/llvm/llvmjit.py
+++ b/llvm/llvmjit.py
@@ -7,7 +7,7 @@ import shutil
 
 from ..data_types import toCtypes, createType, ctypes_from_llvm
 from .llvm import generateLLVM
-from ..cpu.cpujit import buildCTypeArgumentList
+from ..cpu.cpujit import buildCTypeArgumentList, makePythonFunctionIncompleteParams
 
 
 def generate_and_jit(ast):
@@ -18,41 +18,18 @@ def generate_and_jit(ast):
         return compileLLVM(gen.module)
 
 
-def make_python_function(ast, argumentDict={}, func=None):
+def makePythonFunction(ast, argumentDict={}, func=None):
+    if func is None:
+        jit = generate_and_jit(ast)
+        func = jit.get_function_ptr(ast.functionName)
     try:
         args = buildCTypeArgumentList(ast.parameters, argumentDict)
     except KeyError:
         # not all parameters specified yet
-        return make_python_function_incomplete(ast, argumentDict, func)
-    if func is None:
-        jit = generate_and_jit(ast)
-        func = jit.get_function_ptr(ast.functionName)
+        return makePythonFunctionIncompleteParams(ast, argumentDict, func)
     return lambda: func(*args)
 
 
-def make_python_function_incomplete(ast, argumentDict, func=None):
-    if func is None:
-        jit = generate_and_jit(ast)
-        func = jit.get_function_ptr(ast.functionName)
-    parameters = ast.parameters
-
-    cache = {}
-
-    def wrapper(**kwargs):
-        key = hash(tuple((k, id(v)) for k, v in kwargs.items()))
-        try:
-            args = cache[key]
-            func(*args)
-        except KeyError:
-            fullArguments = argumentDict.copy()
-            fullArguments.update(kwargs)
-            args = buildCTypeArgumentList(parameters, fullArguments)
-            cache[key] = args
-            func(*args)
-
-    return wrapper
-
-
 def compileLLVM(module):
     jit = Jit()
     jit.parse(module)
diff --git a/transformations/stage2.py b/transformations/stage2.py
index 92df679d6..42264a309 100644
--- a/transformations/stage2.py
+++ b/transformations/stage2.py
@@ -51,8 +51,8 @@ def insertCasts(node):
     args = []
     for arg in node.args:
         args.append(insertCasts(arg))
-    # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow
-    if node.func in (sp.Add, sp.Mul, sp.Pow):
+    # TODO indexed, SympyAssignment, LoopOverCoordinate
+    if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double
         types = [getTypeOfExpression(arg) for arg in args]
         assert len(types) > 0
         target = collateTypes(types)
-- 
GitLab