diff --git a/field.py b/field.py index b64e567e2ae07f09faf1f574fceadabc5a66bcad..945fa7217553a15a4b05955e444b276e09975c1d 100644 --- a/field.py +++ b/field.py @@ -4,7 +4,7 @@ import numpy as np import sympy as sp from sympy.core.cache import cacheit from sympy.tensor import IndexedBase -from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString +from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType from pystencils.sympyextensions import isIntegerSequence @@ -100,7 +100,6 @@ class Field(object): shape = tuple([shapeSymbol[i] for i in range(totalDimensions)]) else: shape = tuple([shapeSymbol[i] for i in range(spatialDimensions)] + list(indexShape)) - assert len(shape) == totalDimensions strides = tuple([strideSymbol[i] for i in range(totalDimensions)]) @@ -252,6 +251,19 @@ class Field(object): def neighbors(self, stencil): return [self.__getitem__(s) for s in stencil] + @property + def vecCenter(self): + indexShape = self.indexShape + if len(indexShape) == 0: + return self.center + elif len(indexShape) == 1: + return sp.Matrix([self(i) for i in range(indexShape[0])]) + elif len(indexShape) == 2: + def cb(*args): + r = self.__call__(*args) + return r + return sp.Matrix(*indexShape, cb) + @property def center(self): center = tuple([0] * self.spatialDimensions) @@ -281,6 +293,7 @@ class Field(object): otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType) return selfTuple == otherTuple + PREFIX = "f" STRIDE_PREFIX = PREFIX + "stride_" SHAPE_PREFIX = PREFIX + "shape_" @@ -296,7 +309,7 @@ class Field(object): def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None): fieldName = field.name offsetsAndIndex = chain(offsets, idx) if idx is not None else offsets - constantOffsets = not any([isinstance(o, sp.Basic) for o in offsetsAndIndex]) + constantOffsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsetsAndIndex]) if not idx: idx = tuple([0] * field.indexDimensions) @@ -310,6 +323,10 @@ class Field(object): else: idxStr = ",".join([str(e) for e in idx]) superscript = idxStr + if field.hasFixedIndexShape and not isinstance(field.dtype, StructType): + for i, bound in zip(idx, field.indexShape): + if i >= bound: + raise ValueError("Field index out of bounds") else: offsetName = "%0.10X" % (abs(hash(tuple(offsetsAndIndex)))) superscript = None @@ -340,7 +357,6 @@ class Field(object): def __call__(self, *idx): if self._index != tuple([0]*self.field.indexDimensions): - print(self._index, tuple([0]*self.field.indexDimensions)) raise ValueError("Indexing an already indexed Field.Access") idx = tuple(idx) @@ -395,6 +411,11 @@ class Field(object): def getNeighbor(self, *offsets): return Field.Access(self.field, offsets, self.index) + def neighbor(self, coordId, offset): + offsetList = list(self.offsets) + offsetList[coordId] += offset + return Field.Access(self.field, tuple(offsetList), self.index) + def getShifted(self, *shift): return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index) @@ -622,7 +643,8 @@ def directionStringToOffset(directionStr, dim=3): if __name__ == '__main__': - f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=1) + f = Field.createGeneric('f', spatialDimensions=2, indexShape=(2,4)) + f(2, 0) fa = f[0, 1](4) ** 2 print(fa) print(sp.latex(fa)) \ No newline at end of file