From 979ee93bc30d2636241be6ff40b49dfcc766a292 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Victor=20Tozatti=20Risso?=
 <joaovictortr@protonmail.com>
Date: Wed, 13 Dec 2017 12:36:13 +0100
Subject: [PATCH] Code generation for field serialization into buffers

Concept: Generate code involving the (un)packing of fields (from)to linear
(1D) arrays, i.e. (de)serialization of the field values for buffered
communication.

A linear index is generated for the buffer, by inferring the strides and
variables of the loops over fields in the AST. On the CPU, this information is
obtained through the makeLoopOverDomain function, in
pystencils/transformations/transformations.py. On CUDA, the strides of
the fields (excluding buffers) are combined with the indexing variables to infer
the indexing of the buffer.

What is supported:
    - code generation for both CPU and GPU
    - (un)packing of fields with all the memory layouts supported by
    pystencils
    - (un)packing slices of fields (from)into the buffer
    - (un)packing subsets of cell values from the fields (from)into the buffer

Limitations:

- assumes that only one buffer and one field are operated on within
each kernel; however, multiple equations involving the buffer and the
field are supported.

- (un)packing multiple cell values (from)into the buffer is supported;
however, it is limited to fields with indexDimensions=1. The same
applies to (un)packing a subset of the values of each cell.

Changes in this commit:

- add the FieldType enumeration to pystencils/field.py, to mark fields
of various types. This replaces and generalizes the
isIndexedField boolean flag of the Field class. For now, the types
supported are: generic, indexed and buffer fields.

- add the fieldType property to the Field class, which indicates the
type of the field. Modifications were also performed to the member
functions of the Field class to add this property.

- add resolveBufferAccesses function, which replaces accesses to fields
marked as buffers with resolved field accesses during the AST traversal.

Miscellaneous changes:

- add blockDim and gridDim variables as CUDA indexing variables.
---
 __init__.py                        |  2 +-
 cpu/cpujit.py                      |  5 +-
 cpu/kernelcreation.py              | 30 +++++++++---
 field.py                           | 49 ++++++++++++++-----
 gpucuda/cudajit.py                 |  5 +-
 gpucuda/indexing.py                |  6 ++-
 gpucuda/kernelcreation.py          | 33 ++++++++++---
 llvm/kernelcreation.py             |  6 ++-
 transformations/transformations.py | 75 ++++++++++++++++++++++++++++--
 9 files changed, 171 insertions(+), 40 deletions(-)

diff --git a/__init__.py b/__init__.py
index b78f17d66..27b0b97f8 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,4 +1,4 @@
-from pystencils.field import Field, extractCommonSubexpressions
+from pystencils.field import Field, FieldType, extractCommonSubexpressions
 from pystencils.data_types import TypedSymbol
 from pystencils.slicing import makeSlice
 from pystencils.kernelcreation import createKernel, createIndexedKernel
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 69c3d8cd8..28184d6f3 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -74,6 +74,7 @@ from pystencils.backends.cbackend import generateC, getHeaders
 from collections import OrderedDict, Mapping
 from pystencils.transformations import symbolNameToVariableName
 from pystencils.data_types import toCtypes, getBaseType, StructType
+from pystencils.field import FieldType
 
 
 def makePythonFunction(kernelFunctionNode, argumentDict={}):
@@ -387,9 +388,9 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
                         raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                          (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))
 
-                if symbolicField.isIndexField:
+                if FieldType.isIndexed(symbolicField):
                     indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
-                else:
+                elif not FieldType.isBuffer(symbolicField):
                     arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
 
             elif arg.isFieldShapeArgument:
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 298adea4d..91b6dc3eb 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -4,11 +4,11 @@ from functools import partial
 from collections import defaultdict
 
 from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
-from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \
+from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \
     typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
     substituteArrayAccessesWithConstants
 from pystencils.data_types import TypedSymbol, BasicType, StructType, createType
-from pystencils.field import Field
+from pystencils.field import Field, FieldType
 import pystencils.astnodes as ast
 from pystencils.cpu.cpujit import makePythonFunction
 
@@ -51,10 +51,13 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double',
     allFields = fieldsRead.union(fieldsWritten)
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
+    buffers = set([f for f in allFields if FieldType.isBuffer(f)])
+    fieldsWithoutBuffers = allFields - buffers
+
     body = ast.Block(assignments)
-    loopOrder = getOptimalLoopOrdering(allFields)
-    code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice,
-                              ghostLayers=ghostLayers, loopOrder=loopOrder)
+    loopOrder = getOptimalLoopOrdering(fieldsWithoutBuffers)
+    code, loopStrides, loopVars = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice,
+                                                     ghostLayers=ghostLayers, loopOrder=loopOrder)
     code.target = 'cpu'
 
     if splitGroups:
@@ -62,8 +65,20 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double',
         splitInnerLoop(code, typedSplitGroups)
 
     basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']]
-    basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
+    basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field)
+                        for field in fieldsWithoutBuffers}
+
+    bufferBasePointerInfos = {field.name: parseBasePointerInfo([['spatialInner0']], [0], field) for field in buffers}
+    basePointerInfos.update(bufferBasePointerInfos)
+
+    baseBufferIndex = loopVars[0]
+    stride = 1
+    for idx, var in enumerate(loopVars[1:]):
+        curStride = loopStrides[idx]
+        stride *= int(curStride) if isinstance(curStride, float) else curStride
+        baseBufferIndex += var * stride
 
+    resolveBufferAccesses(code, baseBufferIndex, readOnlyFields)
     resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
     substituteArrayAccessesWithConstants(code)
     moveConstantsBeforeLoop(code)
@@ -93,7 +108,8 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ
     allFields = fieldsRead.union(fieldsWritten)
 
     for indexField in indexFields:
-        indexField.isIndexField = True
+        indexField.fieldType = FieldType.INDEXED
+        assert FieldType.isIndexed(indexField)
         assert indexField.spatialDimensions == 1, "Index fields have to be 1D"
 
     nonIndexFields = [f for f in allFields if f not in indexFields]
diff --git a/field.py b/field.py
index d545d15c8..51d08df97 100644
--- a/field.py
+++ b/field.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from itertools import chain
 import numpy as np
 import sympy as sp
@@ -7,6 +8,31 @@ from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFr
 from pystencils.sympyextensions import isIntegerSequence
 
 
+class FieldType(Enum):
+    # generic fields
+    GENERIC = 0
+    # index fields are currently only used for boundary handling
+    # the coordinates are not the loop counters in that case, but are read from this index field
+    INDEXED = 1
+    # communication buffer, used for (un)packing data in communication.
+    BUFFER = 2
+
+    @staticmethod
+    def isGeneric(field):
+        assert isinstance(field, Field)
+        return field.fieldType == FieldType.GENERIC
+
+    @staticmethod
+    def isIndexed(field):
+        assert isinstance(field, Field)
+        return field.fieldType == FieldType.INDEXED
+
+    @staticmethod
+    def isBuffer(field):
+        assert isinstance(field, Field)
+        return field.fieldType == FieldType.BUFFER
+
+
 class Field(object):
     """
     With fields one can formulate stencil-like update rules on structured grids.
@@ -51,7 +77,7 @@ class Field(object):
 
     @staticmethod
     def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout='numpy',
-                      indexShape=None):
+                      indexShape=None, fieldType=FieldType.GENERIC):
         """
         Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
 
@@ -85,7 +111,7 @@ class Field(object):
             shape += (1,)
             strides += (1,)
 
-        return Field(fieldName, dtype, layout, shape, strides)
+        return Field(fieldName, fieldType, dtype, layout, shape, strides)
 
     @staticmethod
     def createFromNumpyArray(fieldName, npArray, indexDimensions=0):
@@ -114,7 +140,7 @@ class Field(object):
             shape += (1,)
             strides += (1,)
 
-        return Field(fieldName, npArray.dtype, spatialLayout, shape, strides)
+        return Field(fieldName, FieldType.GENERIC, npArray.dtype, spatialLayout, shape, strides)
 
     @staticmethod
     def createFixedSize(fieldName, shape, indexDimensions=0, dtype=np.float64, layout='numpy'):
@@ -146,21 +172,20 @@ class Field(object):
         spatialLayout = list(layout)
         for i in range(spatialDimensions, len(layout)):
             spatialLayout.remove(i)
-        return Field(fieldName, dtype, tuple(spatialLayout), shape, strides)
+        return Field(fieldName, FieldType.GENERIC, dtype, tuple(spatialLayout), shape, strides)
 
-    def __init__(self, fieldName, dtype, layout, shape, strides):
+    def __init__(self, fieldName, fieldType, dtype, layout, shape, strides):
         """Do not use directly. Use static create* methods"""
         self._fieldName = fieldName
+        assert isinstance(fieldType, FieldType)
+        self.fieldType = fieldType
         self._dtype = createType(dtype)
         self._layout = normalizeLayout(layout)
         self.shape = shape
         self.strides = strides
-        # index fields are currently only used for boundary handling
-        # the coordinates are not the loop counters in that case, but are read from this index field
-        self.isIndexField = False
 
     def newFieldWithDifferentName(self, newName):
-        return Field(newName, self._dtype, self._layout, self.shape, self.strides)
+        return Field(newName, self.fieldType, self._dtype, self._layout, self.shape, self.strides)
 
     @property
     def spatialDimensions(self):
@@ -243,11 +268,11 @@ class Field(object):
         return Field.Access(self, center)(*args, **kwargs)
 
     def __hash__(self):
-        return hash((self._layout, self.shape, self.strides, self._dtype, self._fieldName))
+        return hash((self._layout, self.shape, self.strides, self._dtype, self.fieldType, self._fieldName))
 
     def __eq__(self, other):
-        selfTuple = (self.shape, self.strides, self.name, self.dtype)
-        otherTuple = (other.shape, other.strides, other.name, other.dtype)
+        selfTuple = (self.shape, self.strides, self.name, self.dtype, self.fieldType)
+        otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType)
         return selfTuple == otherTuple
 
     PREFIX = "f"
diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py
index b815f5c33..3a0fe85f4 100644
--- a/gpucuda/cudajit.py
+++ b/gpucuda/cudajit.py
@@ -2,6 +2,7 @@ import numpy as np
 from pystencils.backends.cbackend import generateC
 from pystencils.transformations import symbolNameToVariableName
 from pystencils.data_types import StructType, getBaseType
+from pystencils.field import FieldType
 
 
 def makePythonFunction(kernelFunctionNode, argumentDict={}):
@@ -119,9 +120,9 @@ def _checkArguments(parameterSpecification, argumentDict):
                         raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                          (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))
 
-                if symbolicField.isIndexField:
+                if FieldType.isIndexed(symbolicField):
                     indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
-                else:
+                elif not FieldType.isBuffer(symbolicField):
                     arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions])
 
     if len(arrayShapes) > 1:
diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py
index 99db1efc9..aa05078e8 100644
--- a/gpucuda/indexing.py
+++ b/gpucuda/indexing.py
@@ -11,6 +11,8 @@ AUTO_BLOCKSIZE_LIMITING = True
 
 BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')]
 THREAD_IDX = [TypedSymbol("threadIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')]
+BLOCK_DIM = [TypedSymbol("blockDim." + coord, createType("int")) for coord in ('x', 'y', 'z')]
+GRID_DIM = [TypedSymbol("gridDim." + coord, createType("int")) for coord in ('x', 'y', 'z')]
 
 
 class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})):
@@ -28,8 +30,8 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})):
 
     @property
     def indexVariables(self):
-        """Sympy symbols for CUDA's block and thread indices"""
-        return BLOCK_IDX + THREAD_IDX
+        """Sympy symbols for CUDA's block and thread indices, and block and grid dimensions. """
+        return BLOCK_IDX + THREAD_IDX + BLOCK_DIM + GRID_DIM
 
     @abc.abstractmethod
     def getCallParameters(self, arrShape):
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 738709362..5bd2f5d7d 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -2,10 +2,10 @@ from functools import partial
 
 from pystencils.gpucuda.indexing import BlockIndexing
 from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape, \
-    substituteArrayAccessesWithConstants
+    substituteArrayAccessesWithConstants, resolveBufferAccesses
 from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
 from pystencils.data_types import TypedSymbol, BasicType, StructType
-from pystencils import Field
+from pystencils import Field, FieldType
 from pystencils.gpucuda.cudajit import makePythonFunction
 
 
@@ -15,11 +15,18 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None,
     allFields = fieldsRead.union(fieldsWritten)
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
+    buffers = set([f for f in allFields if FieldType.isBuffer(f)])
+    fieldsWithoutBuffers = allFields - buffers
+
     fieldAccesses = set()
+    numBufferAccesses = 0
     for eq in listOfEquations:
         fieldAccesses.update(eq.atoms(Field.Access))
 
-    commonShape = getCommonShape(allFields)
+        numBufferAccesses += sum([1 for access in eq.atoms(Field.Access) if FieldType.isBuffer(access.field)])
+
+    commonShape = getCommonShape(fieldsWithoutBuffers)
+
     if iterationSlice is None:
         # determine iteration slice from ghost layers
         if ghostLayers is None:
@@ -34,7 +41,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None,
             for i in range(len(commonShape)):
                 iterationSlice.append(slice(ghostLayers[i][0], -ghostLayers[i][1] if ghostLayers[i][1] > 0 else None))
 
-    indexing = indexingCreator(field=list(allFields)[0], iterationSlice=iterationSlice)
+    indexing = indexingCreator(field=list(fieldsWithoutBuffers)[0], iterationSlice=iterationSlice)
 
     block = Block(assignments)
     block = indexing.guard(block, commonShape)
@@ -46,8 +53,19 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None,
     basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}
 
     coordMapping = {f.name: coordMapping for f in allFields}
-    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
-                         fieldToBasePointerInfo=basePointerInfos)
+
+    loopVars = [numBufferAccesses * i for i in indexing.coordinates]
+    loopStrides = list(fieldsWithoutBuffers)[0].shape
+
+    baseBufferIndex = loopVars[0]
+    stride = 1
+    for idx, var in enumerate(loopVars[1:]):
+        stride *= loopStrides[idx]
+        baseBufferIndex += var * stride
+
+    resolveBufferAccesses(ast, baseBufferIndex, readOnlyFields)
+    resolveFieldAccesses(ast, readOnlyFields, fieldToBasePointerInfo=basePointerInfos,
+                         fieldToFixedCoordinates=coordMapping)
 
     substituteArrayAccessesWithConstants(ast)
 
@@ -73,7 +91,8 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel"
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
     for indexField in indexFields:
-        indexField.isIndexField = True
+        indexField.fieldType = FieldType.INDEXED
+        assert FieldType.isIndexed(indexField)
         assert indexField.spatialDimensions == 1, "Index fields have to be 1D"
 
     nonIndexFields = [f for f in allFields if f not in indexFields]
diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py
index 78cd1166a..121dcdd92 100644
--- a/llvm/kernelcreation.py
+++ b/llvm/kernelcreation.py
@@ -1,7 +1,8 @@
 from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
-from pystencils.transformations import resolveFieldAccesses, \
+from pystencils.transformations import resolveFieldAccesses, resolveBufferAccesses, \
     typeAllEquations, moveConstantsBeforeLoop, insertCasts
 from pystencils.data_types import TypedSymbol, BasicType, StructType
+from pystencils.field import Field, FieldType
 from functools import partial
 from pystencils.llvm.llvmjit import makePythonFunction
 
@@ -57,7 +58,8 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ
     allFields = fieldsRead.union(fieldsWritten)
 
     for indexField in indexFields:
-        indexField.isIndexField = True
+        indexField.fieldType = FieldType.INDEXED
+        assert FieldType.isIndexed(indexField)
         assert indexField.spatialDimensions == 1, "Index fields have to be 1D"
 
     nonIndexFields = [f for f in allFields if f not in indexFields]
diff --git a/transformations/transformations.py b/transformations/transformations.py
index d41cadc01..bf9426f32 100644
--- a/transformations/transformations.py
+++ b/transformations/transformations.py
@@ -6,7 +6,7 @@ import sympy as sp
 from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 
-from pystencils.field import Field, offsetComponentToDirectionString
+from pystencils.field import Field, FieldType, offsetComponentToDirectionString
 from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc
 from pystencils.slicing import normalizeSlice
 import pystencils.astnodes as ast
@@ -71,13 +71,15 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
     """
     # find correct ordering by inspecting participating FieldAccesses
     fieldAccesses = body.atoms(Field.Access)
-    fieldList = [e.field for e in fieldAccesses]
+    # exclude accesses to buffers from fieldList, because buffers are treated separately
+    fieldList = [e.field for e in fieldAccesses if not FieldType.isBuffer(e.field)]
     fields = set(fieldList)
+    numBufferAccesses = len(fieldAccesses) - len(fieldList)
 
     if loopOrder is None:
         loopOrder = getOptimalLoopOrdering(fields)
 
-    shape = getCommonShape(fields)
+    shape = getCommonShape(list(fields))
 
     if iterationSlice is not None:
         iterationSlice = normalizeSlice(iterationSlice, shape)
@@ -88,6 +90,11 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
     if isinstance(ghostLayers, int):
         ghostLayers = [(ghostLayers, ghostLayers)] * len(loopOrder)
 
+    def getLoopStride(begin, end, step):
+        return (end - begin) / step
+
+    loopStrides = []
+    loopVars = []
     currentBody = body
     lastLoop = None
     for i, loopCoordinate in enumerate(reversed(loopOrder)):
@@ -97,6 +104,8 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
             newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, begin, end, 1)
             lastLoop = newLoop
             currentBody = ast.Block([lastLoop])
+            loopStrides.append(getLoopStride(begin, end, 1))
+            loopVars.append(newLoop.loopCounterSymbol)
         else:
             sliceComponent = iterationSlice[loopCoordinate]
             if type(sliceComponent) is slice:
@@ -104,11 +113,16 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
                 newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, sc.start, sc.stop, sc.step)
                 lastLoop = newLoop
                 currentBody = ast.Block([lastLoop])
+                loopStrides.append(getLoopStride(sc.start, sc.stop, sc.step))
+                loopVars.append(newLoop.loopCounterSymbol)
             else:
                 assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
                                                  sp.sympify(sliceComponent))
                 currentBody.insertFront(assignment)
-    return ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName)
+
+    loopVars = [numBufferAccesses * var for var in loopVars]
+    astNode = ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName)
+    return (astNode, loopStrides, loopVars)
 
 
 def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
@@ -133,7 +147,6 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
         (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
     """
     field = fieldAccess.field
-
     offset = 0
     name = ""
     listToHash = []
@@ -158,6 +171,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
         name += "%0.6X" % (abs(hash(tuple(listToHash))))
 
     newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
+
     return newPtr, offset
 
 
@@ -223,6 +237,7 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
     rest = allCoordinates - specifiedCoordinates
     if rest:
         result.append(list(rest))
+
     return result
 
 
@@ -278,6 +293,52 @@ def substituteArrayAccessesWithConstants(astNode):
         for a in astNode.args:
             substituteArrayAccessesWithConstants(a)
 
+def resolveBufferAccesses(astNode, baseBufferIndex, readOnlyFieldNames=set()):
+    def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
+        if isinstance(expr, Field.Access):
+            fieldAccess = expr
+
+            # Do not apply transformation if field is not a buffer
+            if not FieldType.isBuffer(fieldAccess.field):
+                return expr
+
+            buffer = fieldAccess.field
+
+            dtype = PointerType(buffer.dtype, const=buffer.name in readOnlyFieldNames, restrict=True)
+            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(buffer.name)), dtype)
+
+            bufferIndex = baseBufferIndex
+            if len(fieldAccess.index) > 1:
+                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')
+
+            if len(fieldAccess.index) > 0:
+                cellIndex = fieldAccess.index[0]
+                bufferIndex += cellIndex
+
+            result = ast.ResolvedFieldAccess(fieldPtr, bufferIndex, fieldAccess.field, fieldAccess.offsets,
+                                             fieldAccess.index)
+
+            return visitSympyExpr(result, enclosingBlock, sympyAssignment)
+        else:
+            if isinstance(expr, ast.ResolvedFieldAccess):
+                return expr
+
+            newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
+            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
+            return expr.func(*newArgs, **kwargs) if newArgs else expr
+
+    def visitNode(subAst):
+        if isinstance(subAst, ast.SympyAssignment):
+            enclosingBlock = subAst.parent
+            assert type(enclosingBlock) is ast.Block
+            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
+            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
+        else:
+            for i, a in enumerate(subAst.args):
+                visitNode(a)
+
+    return visitNode(astNode)
+
 
 def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
     """
@@ -298,6 +359,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
         if isinstance(expr, Field.Access):
             fieldAccess = expr
             field = fieldAccess.field
+
             if field.name in fieldToBasePointerInfo:
                 basePointerInfo = fieldToBasePointerInfo[field.name]
             else:
@@ -324,6 +386,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
                             coordDict[e] = field.dtype.getElementOffset(accessedFieldName)
                         else:
                             coordDict[e] = fieldAccess.index[e - field.spatialDimensions]
+
                 return coordDict
 
             lastPointer = fieldPtr
@@ -337,6 +400,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
                 lastPointer = newPtr
 
             coordDict = createCoordinateDict(basePointerInfo[0])
+
             _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
             result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field,
                                              fieldAccess.offsets, fieldAccess.index)
@@ -344,6 +408,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
             if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
                 newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
                 result = castFunc(result, newType)
+
             return visitSympyExpr(result, enclosingBlock, sympyAssignment)
         else:
             if isinstance(expr, ast.ResolvedFieldAccess):
-- 
GitLab