Commit bab19cf4 authored by Martin Bauer's avatar Martin Bauer
Browse files

Module to create waLBerla sweeps from pystencils

parent f1b61821
......@@ -17,7 +17,7 @@ class ResolvedFieldAccess(sp.Indexed):
return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
def __getnewargs__(self):
return self.name, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
class Node(object):
......@@ -96,7 +96,7 @@ class Conditional(Node):
class KernelFunction(Node):
class Argument:
def __init__(self, name, dtype, kernelFunctionNode):
def __init__(self, name, dtype, symbol, kernelFunctionNode):
from pystencils.transformations import symbolNameToVariableName
self.name = name
self.dtype = dtype
......@@ -106,6 +106,7 @@ class KernelFunction(Node):
self.isFieldArgument = False
self.fieldName = ""
self.coordinate = None
self.symbol = symbol
if name.startswith(Field.DATA_PREFIX):
self.isFieldPtrArgument = True
......@@ -125,6 +126,23 @@ class KernelFunction(Node):
fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
self.field = fieldMap[self.fieldName]
def __lt__(self, other):
def score(l):
if l.isFieldPtrArgument:
return -4
elif l.isFieldShapeArgument:
return -3
elif l.isFieldStrideArgument:
return -2
return 0
if score(self) < score(other):
return True
elif score(self) == score(other):
return self.name < other.name
else:
return False
def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name)
......@@ -166,10 +184,9 @@ class KernelFunction(Node):
def _updateParameters(self):
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype, self) for s in undefinedSymbols]
self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
l.isFieldStrideArgument, l.name),
reverse=True)
self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols]
self._parameters.sort()
def __str__(self):
self._updateParameters()
......
......@@ -4,13 +4,13 @@ from pystencils.astnodes import Node
from pystencils.types import createType, PointerType
def generateC(astNode):
def generateC(astNode, signatureOnly=False):
"""
Prints the abstract syntax tree as C function
"""
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = createType("double") not in fieldTypes
printer = CBackend(constantsAsFloats=useFloatConstants)
printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly)
return printer(astNode)
......@@ -51,13 +51,14 @@ class PrintNode(CustomCppCode):
class CBackend(object):
def __init__(self, constantsAsFloats=False, sympyPrinter=None):
def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False):
if sympyPrinter is None:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
else:
self.sympyPrinter = sympyPrinter
self._indent = " "
self._signatureOnly = signatureOnly
def __call__(self, node):
return str(self._print(node))
......@@ -72,6 +73,9 @@ class CBackend(object):
def _print_KernelFunction(self, node):
functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
if self._signatureOnly:
return funcDeclaration
body = self._print(node.body)
return funcDeclaration + "\n" + body
......
......@@ -42,7 +42,9 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
shape = _checkArguments(parameters, fullArguments)
indexing = kernelFunctionNode.indexing
dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape, func)
dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape)
dictWithBlockAndThreadNumbers['block'] = tuple(int(i) for i in dictWithBlockAndThreadNumbers['block'])
dictWithBlockAndThreadNumbers['grid'] = tuple(int(i) for i in dictWithBlockAndThreadNumbers['grid'])
args = _buildNumpyArgumentList(parameters, fullArguments)
cache[key] = (args, dictWithBlockAndThreadNumbers)
......
......@@ -31,12 +31,10 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})):
return BLOCK_IDX + THREAD_IDX
@abc.abstractmethod
def getCallParameters(self, arrShape, functionToCall):
def getCallParameters(self, arrShape):
"""
Determine grid and block size for kernel call
:param arrShape: the numeric (not symbolic) shape of the array
:param functionToCall: compile kernel function that should be called. Use this object to get information
about required resources like number of registers
:return: dict with keys 'blocks' and 'threads' with tuple values for number of (x,y,z) threads and blocks
the kernel should be started with
"""
......@@ -87,14 +85,14 @@ class BlockIndexing(AbstractIndexing):
return coordinates[:self._dim]
def getCallParameters(self, arrShape, functionToCall):
def getCallParameters(self, arrShape):
substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None}
widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice),
_getEndFromSlice(self._iterationSlice, arrShape))]
widths = sp.Matrix(widths).subs(substitutionDict)
grid = tuple(math.ceil(length / blockSize) for length, blockSize in zip(widths, self._blockSize))
grid = tuple(sp.ceiling(length / blockSize) for length, blockSize in zip(widths, self._blockSize))
extendBs = (1,) * (3 - len(self._blockSize))
extendGr = (1,) * (3 - len(grid))
......@@ -230,7 +228,7 @@ class LineIndexing(AbstractIndexing):
def coordinates(self):
return [i + offset for i, offset in zip(self._coordinates, _getStartFromSlice(self._iterationSlice))]
def getCallParameters(self, arrShape, functionToCall):
def getCallParameters(self, arrShape):
substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None}
widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice),
......@@ -242,7 +240,7 @@ class LineIndexing(AbstractIndexing):
return 1
else:
idx = self._coordinates.index(cudaIdx)
return int(widths[idx])
return widths[idx]
return {'block': tuple([getShapeOfCudaIdx(idx) for idx in THREAD_IDX]),
'grid': tuple([getShapeOfCudaIdx(idx) for idx in BLOCK_IDX])}
......
......@@ -239,6 +239,10 @@ class BasicType(Type):
def is_other(self):
return self.numpyDtype in np.sctypes['others']
@property
def baseName(self):
return BasicType.numpyNameToC(str(self._dtype))
def __str__(self):
result = BasicType.numpyNameToC(str(self._dtype))
if self.const:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment