Commit 90c9bb3c authored by Martin Bauer's avatar Martin Bauer
Browse files

lbmpy: now can build periodicity into kernel

parent f4cf9352
import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.types import TypedSymbol, createType, get_type_from_sympy
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
class Node(object):
......@@ -301,6 +301,16 @@ class LoopOverCoordinate(Node):
def loopCounterName(self):
return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)
@staticmethod
def isLoopCounterSymbol(symbol):
prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
if not symbol.name.startswith(prefix):
return None
if symbol.dtype != createTypeFromString('int'):
return None
coordinate = int(symbol.name[len(prefix)+1:])
return coordinate
@staticmethod
def getLoopCounterSymbol(coordinateToLoopOver):
return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')
......
from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape
from pystencils.astnodes import Block, KernelFunction, SympyAssignment
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.types import TypedSymbol, BasicType, StructType
from pystencils import Field
......@@ -25,7 +25,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None,
iterationSlice = []
if isinstance(ghostLayers, int):
for i in range(len(commonShape)):
iterationSlice.append(slice(ghostLayers[i], -ghostLayers[i] if ghostLayers[i] > 0 else None))
iterationSlice.append(slice(ghostLayers, -ghostLayers if ghostLayers > 0 else None))
else:
for i in range(len(commonShape)):
iterationSlice.append(slice(ghostLayers[i][0], -ghostLayers[i][1] if ghostLayers[i][1] > 0 else None))
......@@ -46,6 +46,14 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None,
fieldToBasePointerInfo=basePointerInfos)
# add the function which determines #blocks and #threads as additional member to KernelFunction node
# this is used by the jit
# If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity),
# they are defined here
undefinedLoopCounters = {LoopOverCoordinate.isLoopCounterSymbol(s): s for s in ast.body.undefinedSymbols
if LoopOverCoordinate.isLoopCounterSymbol(s) is not None}
for i, loopCounter in undefinedLoopCounters.items():
ast.body.insertFront(SympyAssignment(loopCounter, indexing.coordinates[i]))
ast.indexing = indexing
return ast
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment