diff --git a/astnodes.py b/astnodes.py index f5489b62f2b40f0d5415e73aba5dea7387918c1a..25d4ba4a069b98c1f4307659b87ee2ca48686ff2 100644 --- a/astnodes.py +++ b/astnodes.py @@ -156,13 +156,14 @@ class KernelFunction(Node): def __repr__(self): return '<{0} {1}>'.format(self.dtype, self.name) - def __init__(self, body, functionName="kernel"): + def __init__(self, body, ghostLayers=None, functionName="kernel"): super(KernelFunction, self).__init__() self._body = body body.parent = self self._parameters = None self.functionName = functionName self._body.parent = self + self.ghostLayers = ghostLayers # these variables are assumed to be global, so no automatic parameter is generated for them self.globalVariables = set() diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 18d382d72658e3f9fc1433640af859a2ab167338..718dd1cb93f0d7b84d047a3886cb61634e071a57 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -120,7 +120,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ loopBody.append(assignment) functionBody = Block([loopNode]) - ast = KernelFunction(functionBody, functionName) + ast = KernelFunction(functionBody, functionName=functionName) fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields} resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping) diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index fd0c175644c5b50a9411b0278a8ba7ed470e8fd2..fa067348851381353d83a24a021e71ac79ff96d4 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, block = Block(assignments) block = indexing.guard(block, commonShape) - ast = KernelFunction(block, functionName) + ast = KernelFunction(block, functionName=functionName, ghostLayers=ghostLayers) ast.globalVariables.update(indexing.indexVariables) coordMapping = indexing.coordinates @@ -95,7 +95,7 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" functionBody = Block(coordinateSymbolAssignments + assignments) functionBody = indexing.guard(functionBody, getCommonShape(indexFields)) - ast = KernelFunction(functionBody, functionName) + ast = KernelFunction(functionBody, functionName=functionName) ast.globalVariables.update(indexing.indexVariables) coordMapping = indexing.coordinates diff --git a/transformations.py b/transformations.py index 917a705b5048cdd6b60e3a304217a5e8667bdec4..0975e96aa677bb5b5cb8c1e859d1db17082ea987 100644 --- a/transformations.py +++ b/transformations.py @@ -101,7 +101,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), sp.sympify(sliceComponent)) currentBody.insertFront(assignment) - return ast.KernelFunction(currentBody, functionName) + return ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName) def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): diff --git a/types.py b/types.py index 2fb1fcf9328789a7983ab3f2c00c1b5aa52d77aa..86ad051b3ec63ed33c7e43ffbc39518cb66a54c6 100644 --- a/types.py +++ b/types.py @@ -1,8 +1,8 @@ import ctypes import sympy as sp import numpy as np -# import llvmlite.ir as ir from sympy.core.cache import cacheit +from pystencils.cache import memorycache class TypedSymbol(sp.Symbol): @@ -115,7 +115,6 @@ def toCtypes(dataType): else: return toCtypes.map[dataType.numpyDtype] - toCtypes.map = { np.dtype(np.int8): ctypes.c_int8, np.dtype(np.int16): ctypes.c_int16, @@ -132,33 +131,10 @@ toCtypes.map = { } -#def to_llvmlite_type(data_type): -# """ -# Transforms a given type into ctypes -# :param data_type: Subclass of Type -# :return: llvmlite type object -# """ -# if isinstance(data_type, PointerType): -# return to_llvmlite_type.map[data_type.baseType].as_pointer() -# else: -# return to_llvmlite_type.map[data_type.numpyDType] -# -#to_llvmlite_type.map = { -# np.dtype(np.int8): ir.IntType(8), -# np.dtype(np.int16): ir.IntType(16), -# np.dtype(np.int32): ir.IntType(32), -# np.dtype(np.int64): ir.IntType(64), -# -# # TODO llvmlite doesn't seem to differentiate between Int types -# np.dtype(np.uint8): ir.IntType(8), -# np.dtype(np.uint16): ir.IntType(16), -# np.dtype(np.uint32): ir.IntType(32), -# np.dtype(np.uint64): ir.IntType(64), -# -# np.dtype(np.float32): ir.FloatType(), -# np.dtype(np.float64): ir.DoubleType(), -# # TODO const, restrict, void -#} +def getTypeOfExpression(expr): + if isinstance(expr, TypedSymbol): + return expr.dtype + class Type(sp.Basic):