From 76fb5eab68f65693360c35824b4087d289250de7 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Sat, 1 Jul 2017 18:52:16 +0200 Subject: [PATCH] Bugfix in string to layout conversion --- field.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/field.py b/field.py index c3a83e07e..5be82e170 100644 --- a/field.py +++ b/field.py @@ -62,7 +62,7 @@ class Field(object): over dimension 0. Also allowed: the strings 'numpy' (0,1,..d) or 'reverseNumpy' (d, ..., 1, 0) """ if isinstance(layout, str): - layout = layoutStringToTuple(layout, dim=spatialDimensions) + layout = spatialLayoutStringToTuple(layout, dim=spatialDimensions) shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,)) strideSymbol = IndexedBase(TypedSymbol(Field.STRIDE_PREFIX + fieldName, Field.STRIDE_DTYPE), shape=(1,)) totalDimensions = spatialDimensions + indexDimensions @@ -409,19 +409,29 @@ def createNumpyArrayWithLayout(shape, layout): return res -def layoutStringToTuple(layoutStr, dim): - if layoutStr in ('fzyx', 'zyxf') and dim != 4: - if dim == 3: - return tuple(reversed(range(dim))) - else: - raise ValueError("Layout descriptor " + layoutStr + " only valid for dimension 4, not %d" % (dim,)) +def spatialLayoutStringToTuple(layoutStr, dim): + if layoutStr in ('fzyx', 'zyxf'): + assert dim <= 3 + return tuple(reversed(range(dim))) if layoutStr == "fzyx" or layoutStr == 'f' or layoutStr == 'reverseNumpy': return tuple(reversed(range(dim))) elif layoutStr == 'c' or layoutStr == 'numpy': return tuple(range(dim)) + raise ValueError("Unknown layout descriptor " + layoutStr) + + +def layoutStringToTuple(layoutStr, dim): + if layoutStr == 'fzyx': + assert dim <= 4 + return tuple(reversed(range(dim))) elif layoutStr == 'zyxf': - return tuple(reversed(range(dim-1))) + (dim,) + assert dim <= 4 + return tuple(reversed(range(dim - 1))) + (dim-1,) + elif layoutStr == 'f' or layoutStr == 'reverseNumpy': + return tuple(reversed(range(dim))) + elif layoutStr == 'c' or layoutStr == 'numpy': + return tuple(range(dim)) raise ValueError("Unknown layout descriptor " + layoutStr) -- GitLab