Commit 6627dffa authored by Martin Bauer's avatar Martin Bauer
Browse files

Fields with known index shape, check now for out-ouf-bounds errors

parent 4e5faf85
......@@ -4,7 +4,7 @@ import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase
from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString
from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType
from pystencils.sympyextensions import isIntegerSequence
......@@ -100,7 +100,6 @@ class Field(object):
shape = tuple([shapeSymbol[i] for i in range(totalDimensions)])
else:
shape = tuple([shapeSymbol[i] for i in range(spatialDimensions)] + list(indexShape))
assert len(shape) == totalDimensions
strides = tuple([strideSymbol[i] for i in range(totalDimensions)])
......@@ -252,6 +251,19 @@ class Field(object):
def neighbors(self, stencil):
return [self.__getitem__(s) for s in stencil]
@property
def vecCenter(self):
indexShape = self.indexShape
if len(indexShape) == 0:
return self.center
elif len(indexShape) == 1:
return sp.Matrix([self(i) for i in range(indexShape[0])])
elif len(indexShape) == 2:
def cb(*args):
r = self.__call__(*args)
return r
return sp.Matrix(*indexShape, cb)
@property
def center(self):
center = tuple([0] * self.spatialDimensions)
......@@ -281,6 +293,7 @@ class Field(object):
otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType)
return selfTuple == otherTuple
PREFIX = "f"
STRIDE_PREFIX = PREFIX + "stride_"
SHAPE_PREFIX = PREFIX + "shape_"
......@@ -296,7 +309,7 @@ class Field(object):
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None):
fieldName = field.name
offsetsAndIndex = chain(offsets, idx) if idx is not None else offsets
constantOffsets = not any([isinstance(o, sp.Basic) for o in offsetsAndIndex])
constantOffsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsetsAndIndex])
if not idx:
idx = tuple([0] * field.indexDimensions)
......@@ -310,6 +323,10 @@ class Field(object):
else:
idxStr = ",".join([str(e) for e in idx])
superscript = idxStr
if field.hasFixedIndexShape and not isinstance(field.dtype, StructType):
for i, bound in zip(idx, field.indexShape):
if i >= bound:
raise ValueError("Field index out of bounds")
else:
offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex))))
superscript = None
......@@ -340,7 +357,6 @@ class Field(object):
def __call__(self, *idx):
if self._index != tuple([0]*self.field.indexDimensions):
print(self._index, tuple([0]*self.field.indexDimensions))
raise ValueError("Indexing an already indexed Field.Access")
idx = tuple(idx)
......@@ -395,6 +411,11 @@ class Field(object):
def getNeighbor(self, *offsets):
return Field.Access(self.field, offsets, self.index)
def neighbor(self, coordId, offset):
offsetList = list(self.offsets)
offsetList[coordId] += offset
return Field.Access(self.field, tuple(offsetList), self.index)
def getShifted(self, *shift):
return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index)
......@@ -622,7 +643,8 @@ def directionStringToOffset(directionStr, dim=3):
if __name__ == '__main__':
f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=1)
f = Field.createGeneric('f', spatialDimensions=2, indexShape=(2,4))
f(2, 0)
fa = f[0, 1](4) ** 2
print(fa)
print(sp.latex(fa))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment