From 231fb6aff5f29618347d3bbaee0858e1b72ae7fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Mon, 16 Oct 2017 17:06:31 +0200 Subject: [PATCH] Added demo for spinodal decomposition, which serves as a test. Fixed a severe bug. Renamed makePythonFunction of the llvm backend. Deleted code duplicity. --- cpu/cpujit.py | 10 ++++------ llvm/__init__.py | 2 +- llvm/llvm.py | 3 ++- llvm/llvmjit.py | 35 ++++++----------------------------- transformations/stage2.py | 4 ++-- 5 files changed, 15 insertions(+), 39 deletions(-) diff --git a/cpu/cpujit.py b/cpu/cpujit.py index a64f04529..3e1db4520 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -90,13 +90,13 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}): :return: kernel functor """ # build up list of CType arguments + func = compileAndLoad(kernelFunctionNode) + func.restype = None try: args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict) except KeyError: # not all parameters specified yet - return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict) - func = compileAndLoad(kernelFunctionNode) - func.restype = None + return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func) return lambda: func(*args) @@ -427,9 +427,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): return ctArguments -def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict): - func = compileAndLoad(kernelFunctionNode) - func.restype = None +def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func): parameters = kernelFunctionNode.parameters cache = {} diff --git a/llvm/__init__.py b/llvm/__init__.py index 16cd3d751..77fdd2da6 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1,3 +1,3 @@ from .kernelcreation import createKernel, createIndexedKernel -from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function +from .llvmjit import compileLLVM, generate_and_jit, Jit, makePythonFunction from .llvm import generateLLVM diff --git a/llvm/llvm.py b/llvm/llvm.py index b48ec47bb..0eac24b93 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -71,7 +71,6 @@ class LLVMPrinter(Printer): return val def _print_Pow(self, expr): - #print(expr) base0 = self._print(expr.base) if expr.exp == S.NegativeOne: return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0) @@ -84,6 +83,8 @@ class LLVMPrinter(Printer): return self.builder.call(fn, [base0], "sqrt") if expr.exp == 2: return self.builder.fmul(base0, base0) + elif expr.exp == 3: + return self.builder.fmul(self.builder.fmul(base0, base0), base0) exp0 = self._print(expr.exp) fn = self.ext_fn.get("pow") diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py index d767c0ee5..710439e8b 100644 --- a/llvm/llvmjit.py +++ b/llvm/llvmjit.py @@ -7,7 +7,7 @@ import shutil from ..data_types import toCtypes, createType, ctypes_from_llvm from .llvm import generateLLVM -from ..cpu.cpujit import buildCTypeArgumentList +from ..cpu.cpujit import buildCTypeArgumentList, makePythonFunctionIncompleteParams def generate_and_jit(ast): @@ -18,41 +18,18 @@ def generate_and_jit(ast): return compileLLVM(gen.module) -def make_python_function(ast, argumentDict={}, func=None): +def makePythonFunction(ast, argumentDict={}, func=None): + if func is None: + jit = generate_and_jit(ast) + func = jit.get_function_ptr(ast.functionName) try: args = buildCTypeArgumentList(ast.parameters, argumentDict) except KeyError: # not all parameters specified yet - return make_python_function_incomplete(ast, argumentDict, func) - if func is None: - jit = generate_and_jit(ast) - func = jit.get_function_ptr(ast.functionName) + return makePythonFunctionIncompleteParams(ast, argumentDict, func) return lambda: func(*args) -def make_python_function_incomplete(ast, argumentDict, func=None): - if func is None: - jit = generate_and_jit(ast) - func = jit.get_function_ptr(ast.functionName) - parameters = ast.parameters - - cache = {} - - def wrapper(**kwargs): - key = hash(tuple((k, id(v)) for k, v in kwargs.items())) - try: - args = cache[key] - func(*args) - except KeyError: - fullArguments = argumentDict.copy() - fullArguments.update(kwargs) - args = buildCTypeArgumentList(parameters, fullArguments) - cache[key] = args - func(*args) - - return wrapper - - def compileLLVM(module): jit = Jit() jit.parse(module) diff --git a/transformations/stage2.py b/transformations/stage2.py index 92df679d6..42264a309 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -51,8 +51,8 @@ def insertCasts(node): args = [] for arg in node.args: args.append(insertCasts(arg)) - # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow - if node.func in (sp.Add, sp.Mul, sp.Pow): + # TODO indexed, SympyAssignment, LoopOverCoordinate + if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double types = [getTypeOfExpression(arg) for arg in args] assert len(types) > 0 target = collateTypes(types) -- GitLab