diff --git a/field.py b/field.py index c3a83e07ea1879b0851040008a440e5224c902cc..5be82e17033fe3434c0a1d851482f82c5137a323 100644 --- a/field.py +++ b/field.py @@ -62,7 +62,7 @@ class Field(object): over dimension 0. Also allowed: the strings 'numpy' (0,1,..d) or 'reverseNumpy' (d, ..., 1, 0) """ if isinstance(layout, str): - layout = layoutStringToTuple(layout, dim=spatialDimensions) + layout = spatialLayoutStringToTuple(layout, dim=spatialDimensions) shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,)) strideSymbol = IndexedBase(TypedSymbol(Field.STRIDE_PREFIX + fieldName, Field.STRIDE_DTYPE), shape=(1,)) totalDimensions = spatialDimensions + indexDimensions @@ -409,19 +409,29 @@ def createNumpyArrayWithLayout(shape, layout): return res -def layoutStringToTuple(layoutStr, dim): - if layoutStr in ('fzyx', 'zyxf') and dim != 4: - if dim == 3: - return tuple(reversed(range(dim))) - else: - raise ValueError("Layout descriptor " + layoutStr + " only valid for dimension 4, not %d" % (dim,)) +def spatialLayoutStringToTuple(layoutStr, dim): + if layoutStr in ('fzyx', 'zyxf'): + assert dim <= 3 + return tuple(reversed(range(dim))) if layoutStr == "fzyx" or layoutStr == 'f' or layoutStr == 'reverseNumpy': return tuple(reversed(range(dim))) elif layoutStr == 'c' or layoutStr == 'numpy': return tuple(range(dim)) + raise ValueError("Unknown layout descriptor " + layoutStr) + + +def layoutStringToTuple(layoutStr, dim): + if layoutStr == 'fzyx': + assert dim <= 4 + return tuple(reversed(range(dim))) elif layoutStr == 'zyxf': - return tuple(reversed(range(dim-1))) + (dim,) + assert dim <= 4 + return tuple(reversed(range(dim - 1))) + (dim-1,) + elif layoutStr == 'f' or layoutStr == 'reverseNumpy': + return tuple(reversed(range(dim))) + elif layoutStr == 'c' or layoutStr == 'numpy': + return tuple(range(dim)) raise ValueError("Unknown layout descriptor " + layoutStr)