From 0369495252fb08c473f43f367bff140216fd2054 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 25 Nov 2016 16:56:10 +0100
Subject: [PATCH] New Benchmark with kernels cached in library

---
 ast.py        |   7 +--
 cpu/cpujit.py | 127 +++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 102 insertions(+), 32 deletions(-)

diff --git a/ast.py b/ast.py
index c8b528b1f..81907cdf2 100644
--- a/ast.py
+++ b/ast.py
@@ -74,12 +74,13 @@ class KernelFunction(Node):
         self._body = body
         body.parent = self
         self._parameters = None
-        self._functionName = functionName
+        self.functionName = functionName
         self._body.parent = self
         self._fieldsAccessed = fieldsAccessed
         # these variables are assumed to be global, so no automatic parameter is generated for them
         self.globalVariables = set()
 
+
     @property
     def symbolsDefined(self):
         return set()
@@ -101,10 +102,6 @@ class KernelFunction(Node):
     def args(self):
         return [self._body]
 
-    @property
-    def functionName(self):
-        return self._functionName
-
     @property
     def fieldsAccessed(self):
         """Set of Field instances: fields which are accessed inside this kernel function"""
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 58bc7ee2e..638cbbf09 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -4,7 +4,8 @@ from ctypes import cdll, c_double, c_float, sizeof
 from tempfile import TemporaryDirectory
 from pystencils.backends.cbackend import generateC
 import numpy as np
-
+import pickle
+import hashlib
 
 CONFIG_GCC = {
     'compiler': 'g++',
@@ -60,39 +61,43 @@ def ctypeFromNumpyType(numpyType):
     return typeMap[numpyType]
 
 
+def compile(code, tmpDir, libFile, createAssemblyCode=False):
+    srcFile = os.path.join(tmpDir, 'source.cpp')
+    with open(srcFile, 'w') as sourceFile:
+        print('#include <iostream>', file=sourceFile)
+        print("#include <cmath>", file=sourceFile)
+        print('extern "C" { ', file=sourceFile)
+        print(code, file=sourceFile)
+        print('}', file=sourceFile)
+
+    compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
+    compilerCmd += [srcFile, '-o', libFile]
+    configEnv = CONFIG['env'] if 'env' in CONFIG else {}
+    env = os.environ.copy()
+    env.update(configEnv)
+    subprocess.call(compilerCmd, env=env)
+
+    assembly = None
+    if createAssemblyCode:
+        assemblyFile = os.path.join(tmpDir, "assembly.s")
+        compilerCmd = [CONFIG['compiler'], '-S', '-o', assemblyFile, srcFile] + CONFIG['flags'].split()
+        subprocess.call(compilerCmd, env=env)
+        assembly = open(assemblyFile, 'r').read()
+    return assembly
+
+
 def compileAndLoad(kernelFunctionNode):
     with TemporaryDirectory() as tmpDir:
-        srcFile = os.path.join(tmpDir, 'source.cpp')
-        with open(srcFile, 'w') as sourceFile:
-            print('#include <iostream>', file=sourceFile)
-            print("#include <cmath>", file=sourceFile)
-            print('extern "C" { ', file=sourceFile)
-            print(generateC(kernelFunctionNode), file=sourceFile)
-            print('}', file=sourceFile)
-
-        compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
         libFile = os.path.join(tmpDir, "jit.so")
-        compilerCmd += [srcFile, '-o', libFile]
-        configEnv = CONFIG['env'] if 'env' in CONFIG else {}
-        env = os.environ.copy()
-        env.update(configEnv)
-        subprocess.call(compilerCmd, env=env)
-
-        showAssembly = True
-        if showAssembly:
-            assemblyFile = os.path.join(tmpDir, "assembly.s")
-            compilerCmd = [CONFIG['compiler'], '-S', '-o', assemblyFile, srcFile] + CONFIG['flags'].split()
-            subprocess.call(compilerCmd, env=env)
-            assembly = open(assemblyFile, 'r').read()
-            kernelFunctionNode.assembly = assembly
+        compile(generateC(kernelFunctionNode), tmpDir, libFile)
         loadedJitLib = cdll.LoadLibrary(libFile)
 
     return loadedJitLib
 
 
-def buildCTypeArgumentList(kernelFunctionNode, argumentDict):
+def buildCTypeArgumentList(parameterSpecification, argumentDict):
     ctArguments = []
-    for arg in kernelFunctionNode.parameters:
+    for arg in parameterSpecification:
         if arg.isFieldArgument:
             field = argumentDict[arg.fieldName]
             if arg.isFieldPtrArgument:
@@ -125,7 +130,7 @@ def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict):
         from copy import copy
         fullArguments = copy(argumentDict)
         fullArguments.update(kwargs)
-        args = buildCTypeArgumentList(kernelFunctionNode, fullArguments)
+        args = buildCTypeArgumentList(kernelFunctionNode.parameters, fullArguments)
         func(*args)
     return wrapper
 
@@ -145,10 +150,78 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
     """
     # build up list of CType arguments
     try:
-        args = buildCTypeArgumentList(kernelFunctionNode, argumentDict)
+        args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict)
     except KeyError:
         # not all parameters specified yet
         return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict)
     func = compileAndLoad(kernelFunctionNode)[kernelFunctionNode.functionName]
     func.restype = None
     return lambda: func(*args)
+
+
+class CachedKernel:
+    def __init__(self, configDict, ast, parameterValues):
+        self.configDict = configDict
+        self.ast = ast
+        self.parameterValues = parameterValues
+        self.funcPtr = None
+
+    def __compile(self):
+        self.funcPtr = makePythonFunction(self.ast, self.parameterValues)
+
+    def __call__(self, *args, **kwargs):
+        if self.funcPtr is None:
+            self.__compile()
+        self.funcPtr(*args, **kwargs)
+
+
+def hashToFunctionName(h):
+    res = "func_%s" % (h,)
+    return res.replace('-', 'm')
+
+
+def createLibrary(cachedKernels, libraryFile):
+    libraryInfoFile = libraryFile + ".info"
+
+    with TemporaryDirectory() as tmpDir:
+        code = ""
+        infoDict = {}
+        for cachedKernel in cachedKernels:
+            s = repr(sorted(cachedKernel.configDict.items()))
+            configHash = hashlib.sha1(s.encode()).hexdigest()
+            cachedKernel.ast.functionName = hashToFunctionName(configHash)
+            kernelCode = generateC(cachedKernel.ast)
+            code += kernelCode + "\n"
+            infoDict[configHash] = {'code': kernelCode,
+                                    'parameterValues': cachedKernel.parameterValues,
+                                    'configDict': cachedKernel.configDict,
+                                    'parameterSpecification': cachedKernel.ast.parameters}
+
+        compile(code, tmpDir, libraryFile)
+        pickle.dump(infoDict, open(libraryInfoFile, "wb"))
+
+
+def loadLibrary(libraryFile):
+    libraryInfoFile = libraryFile + ".info"
+
+    libraryFile = cdll.LoadLibrary(libraryFile)
+    libraryInfo = pickle.load(open(libraryInfoFile, 'rb'))
+
+    def getKernel(**kwargs):
+        s = repr(sorted(kwargs.items()))
+        configHash = hashlib.sha1(s.encode()).hexdigest()
+        if configHash not in libraryInfo:
+            raise ValueError("No such kernel in library")
+        func = libraryFile[hashToFunctionName(configHash)]
+        func.restype = None
+
+        def wrapper(**kwargs):
+            from copy import copy
+            fullArguments = copy(libraryInfo[configHash]['parameterValues'])
+            fullArguments.update(kwargs)
+            args = buildCTypeArgumentList(libraryInfo[configHash]['parameterSpecification'], fullArguments)
+            func(*args)
+        wrapper.configDict = libraryInfo[configHash]['configDict']
+        return wrapper
+
+    return getKernel
-- 
GitLab