Commit c0f31ce6 authored by Martin Bauer's avatar Martin Bauer
Browse files

Storing shape & strides in constants, not reading from array

-> reading from arrays was slower with some compilers
parent 60b245f8
......@@ -3,10 +3,10 @@ from operator import attrgetter
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString
from pystencils.slicing import normalizeSlice
import pystencils.astnodes as ast
......@@ -219,6 +219,34 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
return result
def substituteShapeAndStrideWithConstants(kernelFunctionAstNode):
fieldAccesses = kernelFunctionAstNode.atoms(Field.Access)
fields = {fa.field for fa in fieldAccesses}
fields = list(fields)
fields.sort(key=lambda f:
for field in fields:
if all(isinstance(e, Indexed) for e in field.shape):
newShape = []
shapeDtype = getBaseType(createTypeFromString(Field.SHAPE_DTYPE))
shapeDtype.const = False
for i, shape in enumerate(field.shape):
symbol = TypedSymbol("shapeConst_%s_%d" % (, i), shapeDtype)
kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, shape))
field.shape = tuple(newShape)
if all(isinstance(e, Indexed) for e in field.strides):
newStrides = []
strideDtype = getBaseType(createTypeFromString(Field.STRIDE_DTYPE))
strideDtype.const = False
for i, stride in enumerate(field.strides):
symbol = TypedSymbol("strideConst_%s_%d" % (, i), strideDtype)
kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, stride))
field.strides = tuple(newStrides)
def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing
......@@ -231,6 +259,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
counters to index the field these symbols are used as coordinates
:return: transformed AST
fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0]))
fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment