Commit 76fb5eab authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfix in string to layout conversion

parent 3b4deebe
......@@ -62,7 +62,7 @@ class Field(object):
over dimension 0. Also allowed: the strings 'numpy' (0,1,..d) or 'reverseNumpy' (d, ..., 1, 0)
"""
if isinstance(layout, str):
layout = layoutStringToTuple(layout, dim=spatialDimensions)
layout = spatialLayoutStringToTuple(layout, dim=spatialDimensions)
shapeSymbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + fieldName, Field.SHAPE_DTYPE), shape=(1,))
strideSymbol = IndexedBase(TypedSymbol(Field.STRIDE_PREFIX + fieldName, Field.STRIDE_DTYPE), shape=(1,))
totalDimensions = spatialDimensions + indexDimensions
......@@ -409,19 +409,29 @@ def createNumpyArrayWithLayout(shape, layout):
return res
def layoutStringToTuple(layoutStr, dim):
if layoutStr in ('fzyx', 'zyxf') and dim != 4:
if dim == 3:
return tuple(reversed(range(dim)))
else:
raise ValueError("Layout descriptor " + layoutStr + " only valid for dimension 4, not %d" % (dim,))
def spatialLayoutStringToTuple(layoutStr, dim):
if layoutStr in ('fzyx', 'zyxf'):
assert dim <= 3
return tuple(reversed(range(dim)))
if layoutStr == "fzyx" or layoutStr == 'f' or layoutStr == 'reverseNumpy':
return tuple(reversed(range(dim)))
elif layoutStr == 'c' or layoutStr == 'numpy':
return tuple(range(dim))
raise ValueError("Unknown layout descriptor " + layoutStr)
def layoutStringToTuple(layoutStr, dim):
if layoutStr == 'fzyx':
assert dim <= 4
return tuple(reversed(range(dim)))
elif layoutStr == 'zyxf':
return tuple(reversed(range(dim-1))) + (dim,)
assert dim <= 4
return tuple(reversed(range(dim - 1))) + (dim-1,)
elif layoutStr == 'f' or layoutStr == 'reverseNumpy':
return tuple(reversed(range(dim)))
elif layoutStr == 'c' or layoutStr == 'numpy':
return tuple(range(dim))
raise ValueError("Unknown layout descriptor " + layoutStr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment