Skip to content
Snippets Groups Projects
Commit 03694952 authored by Martin Bauer's avatar Martin Bauer
Browse files

New Benchmark with kernels cached in library

parent 340a4f3c
No related merge requests found
......@@ -74,12 +74,13 @@ class KernelFunction(Node):
self._body = body
body.parent = self
self._parameters = None
self._functionName = functionName
self.functionName = functionName
self._body.parent = self
self._fieldsAccessed = fieldsAccessed
# these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set()
@property
def symbolsDefined(self):
return set()
......@@ -101,10 +102,6 @@ class KernelFunction(Node):
def args(self):
return [self._body]
@property
def functionName(self):
return self._functionName
@property
def fieldsAccessed(self):
"""Set of Field instances: fields which are accessed inside this kernel function"""
......
......@@ -4,7 +4,8 @@ from ctypes import cdll, c_double, c_float, sizeof
from tempfile import TemporaryDirectory
from pystencils.backends.cbackend import generateC
import numpy as np
import pickle
import hashlib
CONFIG_GCC = {
'compiler': 'g++',
......@@ -60,39 +61,43 @@ def ctypeFromNumpyType(numpyType):
return typeMap[numpyType]
def compile(code, tmpDir, libFile, createAssemblyCode=False):
srcFile = os.path.join(tmpDir, 'source.cpp')
with open(srcFile, 'w') as sourceFile:
print('#include <iostream>', file=sourceFile)
print("#include <cmath>", file=sourceFile)
print('extern "C" { ', file=sourceFile)
print(code, file=sourceFile)
print('}', file=sourceFile)
compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
compilerCmd += [srcFile, '-o', libFile]
configEnv = CONFIG['env'] if 'env' in CONFIG else {}
env = os.environ.copy()
env.update(configEnv)
subprocess.call(compilerCmd, env=env)
assembly = None
if createAssemblyCode:
assemblyFile = os.path.join(tmpDir, "assembly.s")
compilerCmd = [CONFIG['compiler'], '-S', '-o', assemblyFile, srcFile] + CONFIG['flags'].split()
subprocess.call(compilerCmd, env=env)
assembly = open(assemblyFile, 'r').read()
return assembly
def compileAndLoad(kernelFunctionNode):
with TemporaryDirectory() as tmpDir:
srcFile = os.path.join(tmpDir, 'source.cpp')
with open(srcFile, 'w') as sourceFile:
print('#include <iostream>', file=sourceFile)
print("#include <cmath>", file=sourceFile)
print('extern "C" { ', file=sourceFile)
print(generateC(kernelFunctionNode), file=sourceFile)
print('}', file=sourceFile)
compilerCmd = [CONFIG['compiler']] + CONFIG['flags'].split()
libFile = os.path.join(tmpDir, "jit.so")
compilerCmd += [srcFile, '-o', libFile]
configEnv = CONFIG['env'] if 'env' in CONFIG else {}
env = os.environ.copy()
env.update(configEnv)
subprocess.call(compilerCmd, env=env)
showAssembly = True
if showAssembly:
assemblyFile = os.path.join(tmpDir, "assembly.s")
compilerCmd = [CONFIG['compiler'], '-S', '-o', assemblyFile, srcFile] + CONFIG['flags'].split()
subprocess.call(compilerCmd, env=env)
assembly = open(assemblyFile, 'r').read()
kernelFunctionNode.assembly = assembly
compile(generateC(kernelFunctionNode), tmpDir, libFile)
loadedJitLib = cdll.LoadLibrary(libFile)
return loadedJitLib
def buildCTypeArgumentList(kernelFunctionNode, argumentDict):
def buildCTypeArgumentList(parameterSpecification, argumentDict):
ctArguments = []
for arg in kernelFunctionNode.parameters:
for arg in parameterSpecification:
if arg.isFieldArgument:
field = argumentDict[arg.fieldName]
if arg.isFieldPtrArgument:
......@@ -125,7 +130,7 @@ def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict):
from copy import copy
fullArguments = copy(argumentDict)
fullArguments.update(kwargs)
args = buildCTypeArgumentList(kernelFunctionNode, fullArguments)
args = buildCTypeArgumentList(kernelFunctionNode.parameters, fullArguments)
func(*args)
return wrapper
......@@ -145,10 +150,78 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
"""
# build up list of CType arguments
try:
args = buildCTypeArgumentList(kernelFunctionNode, argumentDict)
args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict)
except KeyError:
# not all parameters specified yet
return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict)
func = compileAndLoad(kernelFunctionNode)[kernelFunctionNode.functionName]
func.restype = None
return lambda: func(*args)
class CachedKernel:
def __init__(self, configDict, ast, parameterValues):
self.configDict = configDict
self.ast = ast
self.parameterValues = parameterValues
self.funcPtr = None
def __compile(self):
self.funcPtr = makePythonFunction(self.ast, self.parameterValues)
def __call__(self, *args, **kwargs):
if self.funcPtr is None:
self.__compile()
self.funcPtr(*args, **kwargs)
def hashToFunctionName(h):
res = "func_%s" % (h,)
return res.replace('-', 'm')
def createLibrary(cachedKernels, libraryFile):
libraryInfoFile = libraryFile + ".info"
with TemporaryDirectory() as tmpDir:
code = ""
infoDict = {}
for cachedKernel in cachedKernels:
s = repr(sorted(cachedKernel.configDict.items()))
configHash = hashlib.sha1(s.encode()).hexdigest()
cachedKernel.ast.functionName = hashToFunctionName(configHash)
kernelCode = generateC(cachedKernel.ast)
code += kernelCode + "\n"
infoDict[configHash] = {'code': kernelCode,
'parameterValues': cachedKernel.parameterValues,
'configDict': cachedKernel.configDict,
'parameterSpecification': cachedKernel.ast.parameters}
compile(code, tmpDir, libraryFile)
pickle.dump(infoDict, open(libraryInfoFile, "wb"))
def loadLibrary(libraryFile):
libraryInfoFile = libraryFile + ".info"
libraryFile = cdll.LoadLibrary(libraryFile)
libraryInfo = pickle.load(open(libraryInfoFile, 'rb'))
def getKernel(**kwargs):
s = repr(sorted(kwargs.items()))
configHash = hashlib.sha1(s.encode()).hexdigest()
if configHash not in libraryInfo:
raise ValueError("No such kernel in library")
func = libraryFile[hashToFunctionName(configHash)]
func.restype = None
def wrapper(**kwargs):
from copy import copy
fullArguments = copy(libraryInfo[configHash]['parameterValues'])
fullArguments.update(kwargs)
args = buildCTypeArgumentList(libraryInfo[configHash]['parameterSpecification'], fullArguments)
func(*args)
wrapper.configDict = libraryInfo[configHash]['configDict']
return wrapper
return getKernel
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment