Skip to content
Snippets Groups Projects
Commit 96566fce authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for symbol names which are not legal C++ variable identifiers

parent ab3fd339
Branches
Tags
No related merge requests found
......@@ -7,6 +7,8 @@ import numpy as np
import pickle
import hashlib
from pystencils.transformations import symbolNameToVariableName
CONFIG_GCC = {
'compiler': 'g++',
'flags': '-Ofast -DNDEBUG -fPIC -shared -march=native -fopenmp',
......@@ -104,6 +106,7 @@ def compileAndLoad(kernelFunctionNode):
def buildCTypeArgumentList(parameterSpecification, argumentDict):
argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
ctArguments = []
for arg in parameterSpecification:
if arg.isFieldArgument:
......
......@@ -74,6 +74,7 @@ class Field:
def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout='numpy'):
"""
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
:param fieldName: symbolic name for the field
:param dtype: numpy data type of the array the kernel is called with later
:param spatialDimensions: see documentation of Field
......
......@@ -4,6 +4,7 @@ import pycuda.autoinit
from pycuda.compiler import SourceModule
from pycuda.gpuarray import GPUArray
from pystencils.backends.cbackend import generateCUDA
from pystencils.transformations import symbolNameToVariableName
def numpyTypeFromString(typename, includePointers=True):
......@@ -34,6 +35,7 @@ def numpyTypeFromString(typename, includePointers=True):
def buildNumpyArgumentList(kernelFunctionNode, argumentDict):
argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
result = []
for arg in kernelFunctionNode.parameters:
if arg.isFieldArgument:
......
......@@ -39,6 +39,8 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None)
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
allFields = fieldsRead.union(fieldsWritten)
code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName)
code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
......@@ -49,7 +51,9 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None)
allFields = fieldsRead.union(fieldsWritten)
basePointerInfo = [['spatialInner0']]
basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}
resolveFieldAccesses(code, readOnlyFields, fieldToFixedCoordinates={'src': coordMapping, 'dst': coordMapping},
coordMapping = {f.name: coordMapping for f in allFields}
resolveFieldAccesses(code, readOnlyFields, fieldToFixedCoordinates=coordMapping,
fieldToBasePointerInfo=basePointerInfos)
# add the function which determines #blocks and #threads as additional member to KernelFunction node
# this is used by the jit
......
......@@ -226,7 +226,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
if field.name in readOnlyFieldNames:
dtype.const = True
fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype)
fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)
lastPointer = fieldPtr
......@@ -392,6 +392,11 @@ def splitInnerLoop(astNode, symbolGroups):
outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))
def symbolNameToVariableName(symbolName):
"""Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names"""
return symbolName.replace("^", "_")
def typeAllEquations(eqs, typeForSymbol):
"""
Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
......@@ -415,7 +420,7 @@ def typeAllEquations(eqs, typeForSymbol):
elif isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
return TypedSymbol(term.name, typeForSymbol[term.name])
return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
else:
newArgs = [processRhs(arg) for arg in term.args]
return term.func(*newArgs) if newArgs else term
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment