diff --git a/cpu/cpujit.py b/cpu/cpujit.py index a64f045298099650295ff9692dba5646c0ab8008..3e1db4520bce529a5022d2b1460cf16c835de4f5 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -90,13 +90,13 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}): :return: kernel functor """ # build up list of CType arguments + func = compileAndLoad(kernelFunctionNode) + func.restype = None try: args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict) except KeyError: # not all parameters specified yet - return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict) - func = compileAndLoad(kernelFunctionNode) - func.restype = None + return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func) return lambda: func(*args) @@ -427,9 +427,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): return ctArguments -def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict): - func = compileAndLoad(kernelFunctionNode) - func.restype = None +def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func): parameters = kernelFunctionNode.parameters cache = {} diff --git a/llvm/__init__.py b/llvm/__init__.py index 16cd3d7511b653ca09eddfcdab9f317f5ef6cded..77fdd2da619c0a715ca8551811a6f59eb7b746ba 100644 --- a/llvm/__init__.py +++ b/llvm/__init__.py @@ -1,3 +1,3 @@ from .kernelcreation import createKernel, createIndexedKernel -from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function +from .llvmjit import compileLLVM, generate_and_jit, Jit, makePythonFunction from .llvm import generateLLVM diff --git a/llvm/llvm.py b/llvm/llvm.py index b48ec47bb564add71af36a5396ff59ca478c4022..0eac24b93f354a97ac5a67f87e8b188aa93c5035 100644 --- a/llvm/llvm.py +++ b/llvm/llvm.py @@ -71,7 +71,6 @@ class LLVMPrinter(Printer): return val def _print_Pow(self, expr): - #print(expr) base0 = self._print(expr.base) if expr.exp == S.NegativeOne: return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0) @@ -84,6 +83,8 @@ class LLVMPrinter(Printer): return self.builder.call(fn, [base0], "sqrt") if expr.exp == 2: return self.builder.fmul(base0, base0) + elif expr.exp == 3: + return self.builder.fmul(self.builder.fmul(base0, base0), base0) exp0 = self._print(expr.exp) fn = self.ext_fn.get("pow") diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py index d767c0ee5a107c6a570323b005f524a0e5dc2d46..710439e8b353a0a91af82a6c444f8b9ece0a8a69 100644 --- a/llvm/llvmjit.py +++ b/llvm/llvmjit.py @@ -7,7 +7,7 @@ import shutil from ..data_types import toCtypes, createType, ctypes_from_llvm from .llvm import generateLLVM -from ..cpu.cpujit import buildCTypeArgumentList +from ..cpu.cpujit import buildCTypeArgumentList, makePythonFunctionIncompleteParams def generate_and_jit(ast): @@ -18,41 +18,18 @@ def generate_and_jit(ast): return compileLLVM(gen.module) -def make_python_function(ast, argumentDict={}, func=None): +def makePythonFunction(ast, argumentDict={}, func=None): + if func is None: + jit = generate_and_jit(ast) + func = jit.get_function_ptr(ast.functionName) try: args = buildCTypeArgumentList(ast.parameters, argumentDict) except KeyError: # not all parameters specified yet - return make_python_function_incomplete(ast, argumentDict, func) - if func is None: - jit = generate_and_jit(ast) - func = jit.get_function_ptr(ast.functionName) + return makePythonFunctionIncompleteParams(ast, argumentDict, func) return lambda: func(*args) -def make_python_function_incomplete(ast, argumentDict, func=None): - if func is None: - jit = generate_and_jit(ast) - func = jit.get_function_ptr(ast.functionName) - parameters = ast.parameters - - cache = {} - - def wrapper(**kwargs): - key = hash(tuple((k, id(v)) for k, v in kwargs.items())) - try: - args = cache[key] - func(*args) - except KeyError: - fullArguments = argumentDict.copy() - fullArguments.update(kwargs) - args = buildCTypeArgumentList(parameters, fullArguments) - cache[key] = args - func(*args) - - return wrapper - - def compileLLVM(module): jit = Jit() jit.parse(module) diff --git a/transformations/stage2.py b/transformations/stage2.py index 92df679d6ff6f3631ee19517873eb515cdd0a96c..42264a309b83ec88770d1393c86f13bb81d4ca98 100644 --- a/transformations/stage2.py +++ b/transformations/stage2.py @@ -51,8 +51,8 @@ def insertCasts(node): args = [] for arg in node.args: args.append(insertCasts(arg)) - # TODO indexed, SympyAssignment, LoopOverCoordinate, Pow - if node.func in (sp.Add, sp.Mul, sp.Pow): + # TODO indexed, SympyAssignment, LoopOverCoordinate + if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double types = [getTypeOfExpression(arg) for arg in args] assert len(types) > 0 target = collateTypes(types)