Skip to content
Snippets Groups Projects
Commit 6627dffa authored by Martin Bauer's avatar Martin Bauer
Browse files

Fields with known index shape, check now for out-ouf-bounds errors

parent 4e5faf85
No related merge requests found
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import sympy as sp import sympy as sp
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType
from pystencils.sympyextensions import isIntegerSequence from pystencils.sympyextensions import isIntegerSequence
...@@ -100,7 +100,6 @@ class Field(object): ...@@ -100,7 +100,6 @@ class Field(object):
shape = tuple([shapeSymbol[i] for i in range(totalDimensions)]) shape = tuple([shapeSymbol[i] for i in range(totalDimensions)])
else: else:
shape = tuple([shapeSymbol[i] for i in range(spatialDimensions)] + list(indexShape)) shape = tuple([shapeSymbol[i] for i in range(spatialDimensions)] + list(indexShape))
assert len(shape) == totalDimensions
strides = tuple([strideSymbol[i] for i in range(totalDimensions)]) strides = tuple([strideSymbol[i] for i in range(totalDimensions)])
...@@ -252,6 +251,19 @@ class Field(object): ...@@ -252,6 +251,19 @@ class Field(object):
def neighbors(self, stencil): def neighbors(self, stencil):
return [self.__getitem__(s) for s in stencil] return [self.__getitem__(s) for s in stencil]
@property
def vecCenter(self):
indexShape = self.indexShape
if len(indexShape) == 0:
return self.center
elif len(indexShape) == 1:
return sp.Matrix([self(i) for i in range(indexShape[0])])
elif len(indexShape) == 2:
def cb(*args):
r = self.__call__(*args)
return r
return sp.Matrix(*indexShape, cb)
@property @property
def center(self): def center(self):
center = tuple([0] * self.spatialDimensions) center = tuple([0] * self.spatialDimensions)
...@@ -281,6 +293,7 @@ class Field(object): ...@@ -281,6 +293,7 @@ class Field(object):
otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType) otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType)
return selfTuple == otherTuple return selfTuple == otherTuple
PREFIX = "f" PREFIX = "f"
STRIDE_PREFIX = PREFIX + "stride_" STRIDE_PREFIX = PREFIX + "stride_"
SHAPE_PREFIX = PREFIX + "shape_" SHAPE_PREFIX = PREFIX + "shape_"
...@@ -296,7 +309,7 @@ class Field(object): ...@@ -296,7 +309,7 @@ class Field(object):
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None): def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None):
fieldName = field.name fieldName = field.name
offsetsAndIndex = chain(offsets, idx) if idx is not None else offsets offsetsAndIndex = chain(offsets, idx) if idx is not None else offsets
constantOffsets = not any([isinstance(o, sp.Basic) for o in offsetsAndIndex]) constantOffsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsetsAndIndex])
if not idx: if not idx:
idx = tuple([0] * field.indexDimensions) idx = tuple([0] * field.indexDimensions)
...@@ -310,6 +323,10 @@ class Field(object): ...@@ -310,6 +323,10 @@ class Field(object):
else: else:
idxStr = ",".join([str(e) for e in idx]) idxStr = ",".join([str(e) for e in idx])
superscript = idxStr superscript = idxStr
if field.hasFixedIndexShape and not isinstance(field.dtype, StructType):
for i, bound in zip(idx, field.indexShape):
if i >= bound:
raise ValueError("Field index out of bounds")
else: else:
offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex)))) offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex))))
superscript = None superscript = None
...@@ -340,7 +357,6 @@ class Field(object): ...@@ -340,7 +357,6 @@ class Field(object):
def __call__(self, *idx): def __call__(self, *idx):
if self._index != tuple([0]*self.field.indexDimensions): if self._index != tuple([0]*self.field.indexDimensions):
print(self._index, tuple([0]*self.field.indexDimensions))
raise ValueError("Indexing an already indexed Field.Access") raise ValueError("Indexing an already indexed Field.Access")
idx = tuple(idx) idx = tuple(idx)
...@@ -395,6 +411,11 @@ class Field(object): ...@@ -395,6 +411,11 @@ class Field(object):
def getNeighbor(self, *offsets): def getNeighbor(self, *offsets):
return Field.Access(self.field, offsets, self.index) return Field.Access(self.field, offsets, self.index)
def neighbor(self, coordId, offset):
offsetList = list(self.offsets)
offsetList[coordId] += offset
return Field.Access(self.field, tuple(offsetList), self.index)
def getShifted(self, *shift): def getShifted(self, *shift):
return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index) return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index)
...@@ -622,7 +643,8 @@ def directionStringToOffset(directionStr, dim=3): ...@@ -622,7 +643,8 @@ def directionStringToOffset(directionStr, dim=3):
if __name__ == '__main__': if __name__ == '__main__':
f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=1) f = Field.createGeneric('f', spatialDimensions=2, indexShape=(2,4))
f(2, 0)
fa = f[0, 1](4) ** 2 fa = f[0, 1](4) ** 2
print(fa) print(fa)
print(sp.latex(fa)) print(sp.latex(fa))
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment