field.py 44.8 KB
Newer Older
1
import functools
Martin Bauer's avatar
Martin Bauer committed
2
import hashlib
3
import operator
Martin Bauer's avatar
Martin Bauer committed
4
5
import pickle
import re
6
from enum import Enum
7
from itertools import chain
Martin Bauer's avatar
Martin Bauer committed
8
9
from typing import List, Optional, Sequence, Set, Tuple

10
11
12
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
Martin Bauer's avatar
Martin Bauer committed
13

14
import pystencils
15
from pystencils.alignedarray import aligned_empty
16
from pystencils.data_types import StructType, TypedSymbol, create_type
17
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
18
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction
Martin Bauer's avatar
Martin Bauer committed
19
from pystencils.sympyextensions import is_integer_sequence
20

21
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
Martin Bauer's avatar
Martin Bauer committed
22

23

24
25
26
27
28
29
30
31
32
33
34
35
36
class FieldType(Enum):
    """Kinds of pystencils fields.

    The static ``is_*`` helpers take a ``Field`` instance (they assert ``isinstance(field, Field)``)
    and compare its ``field_type``. Note that ``is_staggered`` is true for both STAGGERED and STAGGERED_FLUX.
    """
    # generic fields
    GENERIC = 0
    # index fields are currently only used for boundary handling
    # the coordinates are not the loop counters in that case, but are read from this index field
    INDEXED = 1
    # communication buffer, used for (un)packing data in communication.
    BUFFER = 2
    # unsafe fields may be accessed in an absolute fashion - the index depends on the data
    # and thus may lead to out-of-bounds accesses
    CUSTOM = 3
    # staggered field
    STAGGERED = 4
    # staggered field that reverses sign when accessed via opposite direction
    STAGGERED_FLUX = 5

    @staticmethod
    def is_generic(field):
        """Return True if ``field`` is a GENERIC field."""
        assert isinstance(field, Field)
        return FieldType.GENERIC == field.field_type

    @staticmethod
    def is_indexed(field):
        """Return True if ``field`` is an INDEXED field."""
        assert isinstance(field, Field)
        return FieldType.INDEXED == field.field_type

    @staticmethod
    def is_buffer(field):
        """Return True if ``field`` is a BUFFER field."""
        assert isinstance(field, Field)
        return FieldType.BUFFER == field.field_type

    @staticmethod
    def is_custom(field):
        """Return True if ``field`` is a CUSTOM field."""
        assert isinstance(field, Field)
        return FieldType.CUSTOM == field.field_type

    @staticmethod
    def is_staggered(field):
        """Return True if ``field`` is staggered, i.e. STAGGERED or STAGGERED_FLUX."""
        assert isinstance(field, Field)
        return field.field_type in (FieldType.STAGGERED, FieldType.STAGGERED_FLUX)

    @staticmethod
    def is_staggered_flux(field):
        """Return True if ``field`` is a STAGGERED_FLUX field."""
        assert isinstance(field, Field)
        return FieldType.STAGGERED_FLUX == field.field_type
69
70
71


def fields(description=None, index_dimensions=0, layout=None, field_type=FieldType.GENERIC, **kwargs):
    """Creates pystencils fields from a string description.

    Examples:
        Create a 2D scalar and vector field:
            >>> s, v = fields("s, v(2): double[2D]")
            >>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
            >>> assert (v.spatial_dimensions, v.index_dimensions, v.index_shape) == (2, 1, (2,))

        Create an integer field of shape (10, 20):
            >>> f = fields("f : int32[10, 20]")
            >>> f.has_fixed_shape, f.shape
            (True, (10, 20))

        Numpy arrays can be used as template for shape and data type of field:
            >>> arr_s, arr_v = np.zeros([20, 20]), np.zeros([20, 20, 2])
            >>> s, v = fields("s, v(2)", s=arr_s, v=arr_v)
            >>> assert s.index_dimensions == 0 and s.dtype.numpy_dtype == arr_s.dtype
            >>> assert v.index_shape == (2,)

        Format string can be left out, field names are taken from keyword arguments.
            >>> fields(f1=arr_s, f2=arr_s)
            [f1: double[20,20], f2: double[20,20]]

        The keyword names ``index_dimensions`` and ``layout`` have special meaning, don't use them for field names
            >>> f = fields(f=arr_v, index_dimensions=1)
            >>> assert f.index_dimensions == 1
            >>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
            >>> f.layout
            (2, 1, 0)

    Args:
        description: field description such as ``"s, v(2): double[2D]"``; if empty or None, one field is created
                     per keyword argument instead
        index_dimensions: number of trailing index dimensions; only used when no description is given
        layout: layout for fields created from the description (``'numpy'`` if None); must be None when no
                description is given, because the layout is then taken from the arrays
        field_type: `FieldType` given to every created field
        kwargs: numpy arrays used as template for the field with the same name

    Returns:
        None if no field was created, the field itself if exactly one was created, otherwise a list of fields
    """
    result = []
    if description:
        field_descriptions, dtype, shape = _parse_description(description)
        layout = 'numpy' if layout is None else layout
        for field_name, idx_shape in field_descriptions:
            if field_name in kwargs:
                # a template array was given: its own layout, shape and dtype are used
                arr = kwargs[field_name]
                idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
                assert idx_shape_of_arr == idx_shape, \
                    f"Index shape {idx_shape} of field '{field_name}' does not match array shape {arr.shape}"
                f = Field.create_from_numpy_array(field_name, arr, index_dimensions=len(idx_shape),
                                                  field_type=field_type)
            elif isinstance(shape, tuple):
                f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
                                            index_dimensions=len(idx_shape), layout=layout, field_type=field_type)
            elif isinstance(shape, int):
                f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
                                         index_shape=idx_shape, layout=layout, field_type=field_type)
            elif shape is None:
                # no shape in the description: default to a generic 2D field
                f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
                                         index_shape=idx_shape, layout=layout, field_type=field_type)
            else:
                assert False, f"Unsupported shape {shape!r} parsed from field description {description!r}"
            result.append(f)
    else:
        assert layout is None, "Layout can not be specified when creating Field from numpy array"
        for field_name, arr in kwargs.items():
            result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions,
                                                        field_type=field_type))

    if len(result) == 0:
        return None
    elif len(result) == 1:
        return result[0]
    else:
        return result


139
140
141
142
143
144
145
class AbstractField:
    """Common base class of field types; ``Field`` derives from it."""

    class AbstractAccess:
        """Base class for access objects into an ``AbstractField`` (presumably subclassed by ``Field.Access``)."""


class Field(AbstractField):
146
147
148
149
    """
    With fields one can formulate stencil-like update rules on structured grids.
    This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.

Martin Bauer's avatar
Martin Bauer committed
150
    Creating Fields:
Martin Bauer's avatar
Martin Bauer committed
151
152
        The preferred method to create fields is the `fields` function.
        Alternatively one can use one of the static functions `Field.create_generic`, `Field.create_from_numpy_array`
153
        and `Field.create_fixed_size`. Don't instantiate the Field directly!
Martin Bauer's avatar
Martin Bauer committed
154
155
156
157
158
        Fields can be created with known or unknown shapes:

        1. Create a kernel with fixed loop sizes, i.e. the shape of the array is already known.
           This is usually the case if just-in-time compilation directly from Python is done.
           (see `Field.create_from_numpy_array`)
159
        2. create a more general kernel that works for variable array sizes. This can be used to create kernels
Martin Bauer's avatar
Martin Bauer committed
160
           beforehand for a library. (see `Field.create_generic`)
161

Martin Bauer's avatar
Martin Bauer committed
162
    Dimensions and Indexing:
163
164
        A field has spatial and index dimensions, where the spatial dimensions come first.
        The interpretation is that the field has multiple cells in (usually) two or three dimensional space which are
Martin Bauer's avatar
Martin Bauer committed
165
        looped over. Additionally N values are stored per cell. In this case spatial_dimensions is two or three,
        and index_dimensions is one (with an index shape of ``(N,)``). If you want to store a matrix on each point
        in a two dimensional grid, there are four dimensions, two spatial and two index dimensions:
        ``len(arr.shape) == spatial_dims + index_dims``
168

Martin Bauer's avatar
Martin Bauer committed
169
170
171
172
173
        The shape of the index dimension does not have to be specified. Just use the 'index_dimensions' parameter.
        However, it is good practice to define the shape, since out of bounds accesses can be directly detected in this
        case. The shape can be passed with the 'index_shape' parameter of the field creation functions.

        When accessing (indexing) a field the result is a `Field.Access` which is derived from sympy Symbol.
Martin Bauer's avatar
Martin Bauer committed
174
        First specify the spatial offsets in [], then in case index_dimension>0 the indices in ()
Martin Bauer's avatar
Martin Bauer committed
175
        e.g. ``f[-1,0,0](7)``
176

Michael Kuron's avatar
Michael Kuron committed
177
178
179
180
181
182
183
184
    Staggered Fields:
        Staggered fields are used to store a value on a second grid shifted by half a cell with respect to the usual
        grid.

        The first index dimension is used to specify the position on the staggered grid (e.g. 0 means half-way to the
        eastern neighbor, 1 is half-way to the northern neighbor, etc.), while additional indices can be used to store
        multiple values at each position.

Martin Bauer's avatar
Martin Bauer committed
185
    Example using no index dimensions:
186
        >>> a = np.zeros([10, 10])
Martin Bauer's avatar
Martin Bauer committed
187
        >>> f = Field.create_from_numpy_array("f", a, index_dimensions=0)
Martin Bauer's avatar
Martin Bauer committed
188
        >>> jacobi = (f[-1,0] + f[1,0] + f[0,-1] + f[0,1]) / 4
189

Martin Bauer's avatar
Martin Bauer committed
190
    Examples for index dimensions to create LB field and implement stream pull:
Martin Bauer's avatar
Martin Bauer committed
191
        >>> from pystencils import Assignment
192
        >>> stencil = np.array([[0,0], [0,1], [0,-1]])
Martin Bauer's avatar
Martin Bauer committed
193
        >>> src, dst = fields("src(3), dst(3) : double[2D]")
194
        >>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)];
195
    """
196
197

    @staticmethod
    def create_generic(field_name, spatial_dimensions, dtype=np.float64, index_dimensions=0, layout='numpy',
                       index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
        """
        Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes

        Args:
            field_name: symbolic name for the field
            dtype: numpy data type of the array the kernel is called with later
            spatial_dimensions: see documentation of Field
            index_dimensions: see documentation of Field
            layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
                    the outer loop loops over dimension 2, the second outer over dimension 1, and the inner loop
                    over dimension 0. Also allowed: the strings 'numpy' (0,1,..d) or 'reverse_numpy' (d, ..., 1, 0)
            index_shape: optional shape of the index dimensions i.e. maximum values allowed for each index dimension,
                        has to be a list or tuple; if given, it determines the number of index dimensions
            field_type: besides the normal GENERIC fields, there are INDEXED fields that store indices of the domain
                        that should be iterated over, BUFFER fields that are used to generate communication
                        packing/unpacking kernels, and STAGGERED fields, which store values half-way to the next
                        cell

        Raises:
            ValueError: if a structured dtype is combined with index dimensions, or if a staggered field
                        (STAGGERED or STAGGERED_FLUX) has no index dimension
        """
        if index_shape is not None:
            assert index_dimensions == 0 or index_dimensions == len(index_shape)
            index_dimensions = len(index_shape)
        if isinstance(layout, str):
            layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)

        total_dimensions = spatial_dimensions + index_dimensions
        if index_shape is None or len(index_shape) == 0:
            shape = tuple([FieldShapeSymbol([field_name], i) for i in range(total_dimensions)])
        else:
            # spatial extents stay symbolic, the known index extents are used directly
            shape = tuple([FieldShapeSymbol([field_name], i) for i in range(spatial_dimensions)] + list(index_shape))

        strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)])

        np_data_type = np.dtype(dtype)
        if np_data_type.fields is not None:
            if index_dimensions != 0:
                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
            # structured dtypes get one extra trailing dimension with extent 1 and stride 1
            shape += (1,)
            strides += (1,)
        # STAGGERED_FLUX is staggered as well (see FieldType.is_staggered) and needs the same index dimension
        if field_type in (FieldType.STAGGERED, FieldType.STAGGERED_FLUX) and index_dimensions == 0:
            raise ValueError("A staggered field needs at least one index dimension")

        return Field(field_name, field_type, dtype, layout, shape, strides)
242

243
    @staticmethod
    def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0,
                                field_type=FieldType.GENERIC) -> 'Field':
        """Creates a field based on the layout, data type, and shape of a given numpy array.

        Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type.

        Args:
            field_name: symbolic name for the field
            array: numpy array
            index_dimensions: see documentation of Field
            field_type: kind of field

        Raises:
            ValueError: if no spatial dimension is left, if a structured dtype is combined with index dimensions,
                        or if a staggered field (STAGGERED or STAGGERED_FLUX) has no index dimension
        """
        spatial_dimensions = len(array.shape) - index_dimensions
        if spatial_dimensions < 1:
            raise ValueError("Too many index dimensions. At least one spatial dimension required")

        full_layout = get_layout_of_array(array)
        spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions])
        assert len(spatial_layout) == spatial_dimensions

        numpy_dtype = np.dtype(array.dtype)
        # numpy strides are given in bytes; the Field stores them in elements
        strides = tuple([s // numpy_dtype.itemsize for s in array.strides])
        shape = tuple(int(s) for s in array.shape)

        if numpy_dtype.fields is not None:
            if index_dimensions != 0:
                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
            # structured dtypes get one extra trailing dimension with extent 1 and stride 1
            shape += (1,)
            strides += (1,)
        # STAGGERED_FLUX is staggered as well (see FieldType.is_staggered) and needs the same index dimension
        if field_type in (FieldType.STAGGERED, FieldType.STAGGERED_FLUX) and index_dimensions == 0:
            raise ValueError("A staggered field needs at least one index dimension")

        return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides)
277
278

    @staticmethod
    def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
                          dtype=np.float64, layout: str = 'numpy', strides: Optional[Sequence[int]] = None,
                          field_type=FieldType.GENERIC) -> 'Field':
        """
        Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout

        Args:
            field_name: symbolic name for the field
            shape: overall shape of the array
            index_dimensions: how many of the trailing dimensions are interpreted as index (as opposed to spatial)
            dtype: numpy data type of the array the kernel is called with later
            layout: full layout of array, not only spatial dimensions
            strides: strides in bytes or None to automatically compute them from shape (assuming no padding)
            field_type: kind of field

        Raises:
            ValueError: if a structured dtype is combined with index dimensions, or if a staggered field
                        (STAGGERED or STAGGERED_FLUX) has no index dimension
        """
        spatial_dimensions = len(shape) - index_dimensions
        assert spatial_dimensions >= 1

        if isinstance(layout, str):
            layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions)

        shape = tuple(int(s) for s in shape)
        if strides is None:
            strides = compute_strides(shape, layout)
        else:
            assert len(strides) == len(shape)
            # user strides are given in bytes; the Field stores them in elements
            strides = tuple([s // np.dtype(dtype).itemsize for s in strides])

        numpy_dtype = np.dtype(dtype)
        if numpy_dtype.fields is not None:
            if index_dimensions != 0:
                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
            # structured dtypes get one extra trailing dimension with extent 1 and stride 1
            shape += (1,)
            strides += (1,)
        # STAGGERED_FLUX is staggered as well (see FieldType.is_staggered) and needs the same index dimension
        if field_type in (FieldType.STAGGERED, FieldType.STAGGERED_FLUX) and index_dimensions == 0:
            raise ValueError("A staggered field needs at least one index dimension")

        # `layout` orders all dimensions; the Field only keeps the ordering of the spatial ones
        spatial_layout = list(layout)
        for i in range(spatial_dimensions, len(layout)):
            spatial_layout.remove(i)
        return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides)
320

321
    def __init__(self, field_name, field_type, dtype, layout, shape, strides):
        """Do not use directly. Use static create* methods"""
        self._field_name = field_name
        assert isinstance(field_type, FieldType)
        assert len(shape) == len(strides)
        self.field_type = field_type
        self._dtype = create_type(dtype)
        self._layout = normalize_layout(layout)
        self.shape = shape
        self.strides = strides
        self.latex_name = None  # type: Optional[str]
        # spatial_dimensions is derived from self._layout, so these two must be set after it
        self.coordinate_origin = sp.Matrix(tuple(
            0 for _ in range(self.spatial_dimensions)
        ))  # type: sp.Matrix
        self.coordinate_transform = sp.eye(self.spatial_dimensions)
        # validate that the number of staggered points matches a known stencil, for both staggered kinds
        if field_type in (FieldType.STAGGERED, FieldType.STAGGERED_FLUX):
            assert self.staggered_stencil
338

Martin Bauer's avatar
Martin Bauer committed
339
    def new_field_with_different_name(self, new_name):
        """Return a field with the same type, dtype, layout and shape as this one, but named ``new_name``.

        Fixed-shape fields are constructed directly. Generic fields are recreated through ``create_generic`` so that
        their shape and stride symbols are built for ``new_name``. ``latex_name`` and the coordinate
        origin/transform are not copied; the new field gets the defaults set in ``__init__``.
        """
        if self.has_fixed_shape:
            return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)

        numpy_dtype = self.dtype.numpy_dtype
        if np.dtype(numpy_dtype).fields is not None:
            # The create* methods append an implicit trailing dimension of extent 1 for structured dtypes. It is not
            # a real index dimension, and create_generic rejects index dimensions for structured dtypes while
            # re-adding that dimension itself, so it must not be passed on.
            return Field.create_generic(new_name, self.spatial_dimensions, numpy_dtype,
                                        layout=self._layout, field_type=self.field_type)
        return Field.create_generic(new_name, self.spatial_dimensions, numpy_dtype,
                                    self.index_dimensions, self._layout, self.index_shape, self.field_type)
345

346
    @property
    def spatial_dimensions(self) -> int:
        """Number of spatial dimensions, i.e. the length of the stored spatial layout."""
        spatial_layout = self._layout
        return len(spatial_layout)

    @property
    def index_dimensions(self) -> int:
        """Number of trailing dimensions of ``shape`` that are not spatial."""
        return len(self.shape) - self.spatial_dimensions
353

354
355
356
357
    @property
    def ndim(self) -> int:
        """Total number of dimensions (spatial plus index)."""
        total = len(self.shape)
        return total

358
359
360
    def values_per_cell(self) -> int:
        """Number of values per cell: the product of the index-shape extents (1 without index dimensions)."""
        count = 1
        for extent in self.index_shape:
            count = count * extent
        return count

361
362
363
364
365
    @property
    def layout(self):
        """Normalized ordering of the spatial dimensions (as set in ``__init__``)."""
        spatial_layout = self._layout
        return spatial_layout

    @property
    def name(self) -> str:
        """Symbolic name of the field."""
        field_name = self._field_name
        return field_name
368
369

    @property
    def spatial_shape(self) -> Tuple[int, ...]:
        """Leading part of ``shape`` that belongs to the spatial dimensions."""
        n_spatial = self.spatial_dimensions
        return self.shape[0:n_spatial]
372

373
    @property
    def has_fixed_shape(self):
        """True if ``shape`` is an integer sequence according to ``is_integer_sequence``."""
        full_shape = self.shape
        return is_integer_sequence(full_shape)
376

377
    @property
    def index_shape(self):
        """Trailing part of ``shape`` that belongs to the index dimensions."""
        n_spatial = self.spatial_dimensions
        return self.shape[n_spatial:]
380

381
    @property
    def has_fixed_index_shape(self):
        """True if ``index_shape`` is an integer sequence according to ``is_integer_sequence``."""
        trailing_shape = self.index_shape
        return is_integer_sequence(trailing_shape)
384

385
    @property
    def spatial_strides(self):
        """Leading part of ``strides`` that belongs to the spatial dimensions."""
        n_spatial = self.spatial_dimensions
        return self.strides[0:n_spatial]
388
389

    @property
    def index_strides(self):
        """Trailing part of ``strides`` that belongs to the index dimensions."""
        n_spatial = self.spatial_dimensions
        return self.strides[n_spatial:]
392
393
394
395
396

    @property
    def dtype(self):
        """Data type of the field, as produced by ``create_type`` in ``__init__``."""
        data_type = self._dtype
        return data_type

397
398
399
400
    @property
    def itemsize(self):
        """Size in bytes of one element, taken from the numpy dtype of ``dtype``."""
        numpy_dtype = self.dtype.numpy_dtype
        return numpy_dtype.itemsize

401
    def __repr__(self):
        """Compact form ``name(index_shape): dtype[spatial_shape]``.

        The index part is omitted when there are no index dimensions; a symbolic spatial shape is printed as ``<n>d``.
        """
        spatial = self.spatial_shape
        if any(isinstance(extent, sp.Symbol) for extent in spatial):
            spatial_part = f'{self.spatial_dimensions}d'
        else:
            spatial_part = ','.join(str(extent) for extent in spatial)
        index_shape = self.index_shape
        index_part = '(' + ','.join(str(i) for i in index_shape) + ')' if index_shape else ''
        return f'{self._field_name}{index_part}: {self.dtype}[{spatial_part}]'

    def __str__(self):
        """The field's name."""
        field_name = self.name
        return field_name
415

Martin Bauer's avatar
Martin Bauer committed
416
417
418
419
    def neighbor(self, coord_id, offset):
        """Access shifted by ``offset`` along spatial axis ``coord_id``, with zero offset in all other axes."""
        offsets = [0 for _ in range(self.spatial_dimensions)]
        offsets[coord_id] = offset
        return Field.Access(self, tuple(offsets))
420

421
    def neighbors(self, stencil):
        """List of accesses, one per stencil entry, each obtained by indexing the field with that entry."""
        accesses = []
        for direction in stencil:
            accesses.append(self[direction])
        return accesses
423

424
    @property
    def center_vector(self):
        """All values stored at the cell center as a sympy Matrix.

        With no index dimensions this is a 1x1 matrix holding ``center``; for up to three index dimensions the
        entries are ``self(i, ...)`` for every index in ``index_shape``.

        Raises:
            NotImplementedError: for more than 3 index dimensions
        """
        shape = self.index_shape
        rank = len(shape)
        if rank == 0:
            return sp.Matrix([self.center])
        if rank == 1:
            return sp.Matrix([self(i) for i in range(shape[0])])
        if rank == 2:
            return sp.Matrix([[self(i, j) for j in range(shape[1])] for i in range(shape[0])])
        if rank == 3:
            return sp.Matrix([[[self(i, j, k) for k in range(shape[2])]
                               for j in range(shape[1])] for i in range(shape[0])])
        raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
438

439
    @property
    def center(self):
        """Access with zero offset in every spatial dimension."""
        zero_offset = (0,) * self.spatial_dimensions
        return Field.Access(self, zero_offset)

444
445
446
447
    def __getitem__(self, offset):
        """Access at a spatial offset.

        ``offset`` may be a tuple, a numpy array, a direction string (converted with
        ``direction_string_to_offset``) or a single value for one spatial dimension.

        Raises:
            ValueError: if the number of offsets differs from ``spatial_dimensions``
        """
        if type(offset) is np.ndarray:
            offset = tuple(offset)
        elif type(offset) is str:
            offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions))
        elif type(offset) is not tuple:
            offset = (offset,)

        expected = self.spatial_dimensions
        if len(offset) != expected:
            raise ValueError("Wrong number of spatial indices: "
                             "Got %d, expected %d" % (len(offset), expected))
        return Field.Access(self, offset)

Martin Bauer's avatar
Martin Bauer committed
456
    def absolute_access(self, offset, index):
        """Access at absolute coordinates ``offset`` with index ``index``; only permitted for CUSTOM fields."""
        assert FieldType.is_custom(self)
        access = Field.Access(self, offset, index, is_absolute_access=True)
        return access

460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    def interpolated_access(self,
                            offset: Tuple,
                            interpolation_mode='linear',
                            address_mode='BORDER',
                            allow_textures=True):
        """Provides access to field values at non-integer positions

        ``interpolated_access`` is similar to :func:`Field.absolute_access` except that
        it allows non-integer offsets and automatic handling of out-of-bound accesses.

        :param offset:              Tuple of spatial coordinates (can be floats)
        :param interpolation_mode:  One of :class:`pystencils.interpolation_astnodes.InterpolationMode`
        :param address_mode:        How boundaries are handled can be 'border', 'wrap', 'mirror', 'clamp'
        :param allow_textures:      Allow implementation by texture accesses on GPUs
        """
        # imported locally, as in the original implementation
        from pystencils.interpolation_astnodes import Interpolator
        interpolator = Interpolator(self, interpolation_mode, address_mode, allow_textures=allow_textures)
        return interpolator.at(offset)

Michael Kuron's avatar
Michael Kuron committed
481
482
483
484
485
486
487
    def staggered_access(self, offset, index=None):
        """If this field is a staggered field, it can be accessed using half-integer offsets.

        For example, an offset of ``(sp.Rational(1, 2), 0)`` or ``"E"`` corresponds to the staggered point to the
        east of the cell center, i.e. half-way to the eastern-next cell.
        If the field stores more than one value per staggered point (e.g. a vector or a tensor), the index (integer or
        tuple of integers) refers to which of these values to access.

        Only the directions listed in ``staggered_stencil`` are stored. A staggered point in the opposite direction
        is mapped to the stored point of the neighboring cell; for STAGGERED_FLUX fields the result is negated.

        Args:
            offset: tuple, numpy array or single value of (half-integer) offsets, or a direction string
            index: index into the value stored per staggered point; must be None if only one value is stored

        Raises:
            ValueError: wrong number of spatial indices or indices, or the offset is not a staggered point of the
                        field's stencil
        """
        assert FieldType.is_staggered(self)

        offset_orig = offset
        if type(offset) is np.ndarray:
            offset = tuple(offset)
        if type(offset) is str:
            offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions))
            # direction strings give unit offsets; the staggered point lies half-way
            offset = tuple([o * sp.Rational(1, 2) for o in offset])
        if type(offset) is not tuple:
            offset = (offset,)
        if len(offset) != self.spatial_dimensions:
            raise ValueError("Wrong number of spatial indices: "
                             "Got %d, expected %d" % (len(offset), self.spatial_dimensions))

        prefactor = 1
        # the half-integer components of the offset determine the direction of the staggered point
        neighbor_vec = [0] * len(offset)
        for i in range(self.spatial_dimensions):
            if (offset[i] + sp.Rational(1, 2)).is_Integer:
                neighbor_vec[i] = sp.sign(offset[i])
        neighbor = offset_to_direction_string(neighbor_vec)
        if neighbor not in self.staggered_stencil:
            # only one of two opposite directions is stored: use the opposite direction instead
            neighbor_vec = inverse_direction(neighbor_vec)
            neighbor = offset_to_direction_string(neighbor_vec)
            if FieldType.is_staggered_flux(self):
                prefactor = -1
        if neighbor not in self.staggered_stencil:
            raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
                             self.staggered_stencil_name))

        # move from the staggered point back to the cell that stores it
        offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))

        idx = self.staggered_stencil.index(neighbor)

        if self.index_dimensions == 1:  # this field stores a scalar value at each staggered position
            if index is not None:
                raise ValueError("Cannot specify an index for a scalar staggered field")
            return prefactor * Field.Access(self, offset, (idx,))
        else:  # this field stores a vector or tensor at each staggered position
            if index is None:
                raise ValueError("Wrong number of indices: "
                                 "Got %d, expected %d" % (0, self.index_dimensions - 1))
            if type(index) is np.ndarray:
                index = tuple(index)
            if type(index) is not tuple:
                index = (index,)
            if self.index_dimensions != len(index) + 1:
                raise ValueError("Wrong number of indices: "
                                 "Got %d, expected %d" % (len(index), self.index_dimensions - 1))

            return prefactor * Field.Access(self, offset, (idx, *index))
Michael Kuron's avatar
Michael Kuron committed
538

539
540
541
542
543
544
545
546
547
548
549
550
    def staggered_vector_access(self, offset):
        """Like staggered_access, but returns the entire vector/tensor stored at offset as a sympy Matrix.

        Raises:
            NotImplementedError: if the field has more than 3 index dimensions
        """
        assert FieldType.is_staggered(self)

        n_idx = self.index_dimensions
        if n_idx == 1:
            return sp.Matrix([self.staggered_access(offset)])
        if n_idx == 2:
            return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
        if n_idx == 3:
            return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
                              for i in range(self.index_shape[1])])
        raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
552

553
554
555
556
557
    @property
    def staggered_stencil(self):
        """Directions of the staggered points stored per cell.

        The stencil is selected by the spatial dimension and by the number of staggered points (``index_shape[0]``).

        Raises:
            ValueError: if no known stencil has ``index_shape[0]`` staggered points
        """
        assert FieldType.is_staggered(self)
        known_stencils = {
            2: {
                2: ["W", "S"],  # D2Q5
                4: ["W", "S", "SW", "NW"]  # D2Q9
            },
            3: {
                3: ["W", "S", "B"],  # D3Q7
                7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"],  # D3Q15
                9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"],  # D3Q19
                13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"]  # D3Q27
            }
        }
        num_points = self.index_shape[0]
        by_size = known_stencils[self.spatial_dimensions]
        if num_points not in by_size:
            raise ValueError("No known stencil has {} staggered points".format(num_points))
        return by_size[num_points]

572
573
574
575
576
    @property
    def staggered_stencil_name(self):
        """Name of the staggered stencil in DdQq notation, e.g. ``"D2Q5"`` for 2 staggered points in 2D.

        The number of directions is ``2 * index_shape[0] + 1``. This matches the names annotated in
        ``staggered_stencil`` (2 points -> D2Q5, 4 points -> D2Q9, ...).
        """
        assert FieldType.is_staggered(self)
        q = 2 * self.index_shape[0] + 1
        return "D%dQ%d" % (self.spatial_dimensions, q)

577
    def __call__(self, *args, **kwargs):
        """Index the zero-offset (center) access of this field, e.g. ``f(1)``.

        All arguments are forwarded to ``Field.Access.__call__`` of an access at offset zero in every
        spatial dimension.
        """
        origin = (0,) * self.spatial_dimensions
        center_access = Field.Access(self, origin)
        return center_access(*args, **kwargs)

581
    def hashable_contents(self):
        """Tuple of all attributes that define this field's identity.

        ``__hash__`` and ``__eq__`` are both computed from this tuple, so they stay consistent.
        """
        contents = (
            self._layout,
            self.shape,
            self.strides,
            self.field_type,
            self._field_name,
            self.latex_name,
            self._dtype,
        )
        return contents
589

590
    def __hash__(self):
        # Hash the same tuple that __eq__ compares, so equal fields always hash equally.
        contents = self.hashable_contents()
        return hash(contents)
592
593

    def __eq__(self, other):
        """Fields are equal when ``other`` is a Field with identical ``hashable_contents()``."""
        if isinstance(other, Field):
            return self.hashable_contents() == other.hashable_contents()
        return False
597

598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
    @property
    def physical_coordinates(self):
        """Symbolic physical coordinates: ``coordinate_transform @ (coordinate_origin + x)``.

        Here ``x`` is ``pystencils.x_vector(spatial_dimensions)``.
        """
        index_vector = pystencils.x_vector(self.spatial_dimensions)
        shifted = self.coordinate_origin + index_vector
        return self.coordinate_transform @ shifted

    @property
    def physical_coordinates_staggered(self):
        """Like ``physical_coordinates``, but built from ``pystencils.x_staggered_vector``."""
        staggered_index_vector = pystencils.x_staggered_vector(self.spatial_dimensions)
        shifted = self.coordinate_origin + staggered_index_vector
        return self.coordinate_transform @ shifted

    def index_to_physical(self, index_coordinates, staggered=False):
        """Map index coordinates to physical coordinates.

        Computes ``coordinate_transform @ (coordinate_origin + index_coordinates)``. With ``staggered=True``
        every index coordinate is shifted by 0.5 first. ``physical_to_index`` performs the inverse mapping.

        Args:
            index_coordinates: a sympy matrix, or any sequence with one entry per spatial dimension
                (converted to a sympy column vector)
            staggered: shift every index coordinate by half a cell before transforming

        Returns:
            the physical coordinates as a sympy matrix
        """
        if staggered:
            # Iterating a sympy matrix visits its entries, so any iterable input yields a column vector here.
            index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates])
        elif not isinstance(index_coordinates, sp.MatrixBase):
            # Plain lists/tuples cannot be added to a sympy matrix directly.
            index_coordinates = sp.Matrix(index_coordinates)
        return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)

    def physical_to_index(self, physical_coordinates, staggered=False):
        """Map physical coordinates back to index coordinates (inverse of ``index_to_physical``).

        Computes ``coordinate_transform.inv() @ physical_coordinates - coordinate_origin``. With
        ``staggered=True`` every resulting coordinate is additionally shifted by -0.5.

        Args:
            physical_coordinates: a sympy matrix, or any sequence with one entry per spatial dimension
                (converted to a sympy column vector)
            staggered: undo the half-cell shift applied by ``index_to_physical(..., staggered=True)``

        Returns:
            the index coordinates as a sympy matrix
        """
        if not isinstance(physical_coordinates, sp.MatrixBase):
            physical_coordinates = sp.Matrix(physical_coordinates)
        index_coordinates = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
        if staggered:
            index_coordinates = sp.Matrix([i - 0.5 for i in index_coordinates])
        return index_coordinates

    def index_to_staggered_physical_coordinates(self, symbol_vector):
        """Shift ``symbol_vector`` by half a cell in every spatial dimension and pass it on.

        The shifted vector is handed to ``self.create_physical_coordinates``.
        """
        # NOTE(review): `create_physical_coordinates` is not defined in the visible part of this class, while
        # `index_to_physical(..., staggered=True)` performs a similar half-cell shift -- confirm the method exists.
        # NOTE(review): for sympy matrices `+=` presumably rebinds to a new matrix; for in-place-mutable inputs it
        # may modify the caller's object -- confirm intended.
        symbol_vector += sp.Matrix([0.5] * self.spatial_dimensions)
        return self.create_physical_coordinates(symbol_vector)

    def set_coordinate_origin_to_field_center(self):
        """Set ``coordinate_origin`` to minus half the spatial shape.

        Index coordinates at the middle of the domain then map to physical coordinate zero (see
        ``index_to_physical``).
        """
        half_extent = sp.Matrix([extent / 2 for extent in self.spatial_shape])
        self.coordinate_origin = -half_extent

Martin Bauer's avatar
Martin Bauer committed
626
    # noinspection PyAttributeOutsideInit,PyUnresolvedReferences
    class Access(TypedSymbol, AbstractField.AbstractAccess):
        """Class representing a relative access into a `Field`.

        This class behaves like a normal sympy Symbol, it is actually derived from it. One can build up
        sympy expressions using field accesses, solve for them, etc.

        Examples:
            >>> vector_field_2d = fields("v(2): double[2D]")  # create a 2D vector field
            >>> northern_neighbor_y_component = vector_field_2d[0, 1](1)
            >>> northern_neighbor_y_component
            v_N^1
            >>> central_y_component = vector_field_2d(1)
            >>> central_y_component
            v_C^1
            >>> central_y_component.get_shifted(1, 0)  # move the existing access
            v_E^1
            >>> central_y_component.at_index(0)  # change component
            v_C^0
        """

        def __new__(cls, name, *args, **kwargs):
            # Construction goes through the sympy `cacheit`-memoized constructor below.
            # NOTE: despite its name, `name` is the Field being accessed.
            obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
            return obj

        def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None):
            """Build the underlying symbol and attach the access metadata.

            Args:
                field: the Field being accessed
                offsets: spatial offsets (ints or sympy expressions)
                idx: index coordinates; ``None`` or empty means all zeros
                is_absolute_access: whether ``offsets`` are absolute coordinates instead of relative ones
                dtype: accepted for pickling compatibility (see ``__getnewargs__``); the symbol always uses
                    ``field.dtype``
            """
            field_name = field.name
            offsets_and_index = (*offsets, *idx) if idx is not None else offsets
            constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])

            if not idx:
                idx = tuple([0] * field.index_dimensions)

            if constant_offsets:
                offset_name = offset_to_direction_string(offsets)
                if field.index_dimensions == 0:
                    superscript = None
                elif field.index_dimensions == 1:
                    superscript = str(idx[0])
                else:
                    idx_str = ",".join([str(e) for e in idx])
                    superscript = idx_str
                if field.has_fixed_index_shape and not isinstance(field.dtype, StructType):
                    for i, bound in zip(idx, field.index_shape):
                        if i >= bound:
                            raise ValueError("Field index out of bounds")
            else:
                # Symbolic offsets/indices have no direction name; derive a short, deterministic one instead.
                offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12]
                superscript = None

            symbol_name = "%s_%s" % (field_name, offset_name)
            if superscript is not None:
                symbol_name += "^" + superscript

            obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
            obj._field = field
            obj._offsets = []
            for o in offsets:
                if isinstance(o, sp.Basic):
                    obj._offsets.append(o)
                else:
                    obj._offsets.append(int(o))
            obj._offsets = tuple(obj._offsets)
            obj._offsetName = offset_name
            obj._superscript = superscript
            obj._index = idx

            # Collect fields that appear inside symbolic offsets/indices (indirect addressing).
            obj._indirect_addressing_fields = set()
            for e in chain(obj._offsets, obj._index):
                if isinstance(e, sp.Basic):
                    obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))

            obj._is_absolute_access = is_absolute_access
            return obj

        def __getnewargs__(self):
            return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype

        # noinspection SpellCheckingInspection
        __xnew__ = staticmethod(__new_stage2__)
        # noinspection SpellCheckingInspection
        __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

        def __call__(self, *idx):
            """Return a new access with index coordinates ``idx`` at the same offsets.

            Raises:
                ValueError: if this access is already indexed or the number of indices is wrong.
            """
            if self._index != tuple([0] * self.field.index_dimensions):
                raise ValueError("Indexing an already indexed Field.Access")

            idx = tuple(idx)

            # Allow f(0) on scalar fields.
            if self.field.index_dimensions == 0 and idx == (0,):
                idx = ()

            if len(idx) != self.field.index_dimensions:
                raise ValueError("Wrong number of indices: "
                                 "Got %d, expected %d" % (len(idx), self.field.index_dimensions))
            return Field.Access(self.field, self._offsets, idx,
                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)

        def __getitem__(self, *idx):
            """Same as ``__call__``: ``a[i]`` equals ``a(i)`` and ``a[i, j]`` equals ``a(i, j)``."""
            # Python passes ``a[i, j]`` as one tuple argument; unpack it so multi-index access works.
            if len(idx) == 1 and isinstance(idx[0], tuple):
                idx = idx[0]
            return self.__call__(*idx)

        def __iter__(self):
            """This is necessary to work with parts of sympy that test if an object is iterable (e.g. simplify).
            The __getitem__ would make it iterable"""
            raise TypeError("Field access is not iterable")

        @property
        def field(self) -> 'Field':
            """Field that the Access points to"""
            return self._field

        @property
        def offsets(self) -> Tuple:
            """Spatial offset as tuple"""
            return self._offsets

        @property
        def required_ghost_layers(self) -> int:
            """Largest spatial distance that is accessed."""
            return int(np.max(np.abs(self._offsets)))

        @property
        def nr_of_coordinates(self):
            """Number of spatial offset coordinates."""
            return len(self._offsets)

        @property
        def offset_name(self) -> str:
            """Spatial offset as string, East-West for x, North-South for y and Top-Bottom for z coordinate.

            Example:
                >>> f = fields("f: double[2D]")
                >>> f[1, 1].offset_name  # north-east
                'NE'
            """
            return self._offsetName

        @property
        def index(self):
            """Value of index coordinates as tuple."""
            return self._index

        def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
            """Returns a new Access with changed spatial coordinates.

            Args:
                coord_id: index of the coordinate to change (0 for x, 1 for y,...)
                offset: incremental change of this coordinate

            Example:
                >>> f = fields('f: [2D]')
                >>> f[0,0].neighbor(coord_id=1, offset=-1)
                f_S
            """
            offset_list = list(self.offsets)
            offset_list[coord_id] += offset
            return Field.Access(self.field, tuple(offset_list), self.index,
                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)

        def get_shifted(self, *shift) -> 'Field.Access':
            """Returns a new Access with changed spatial coordinates

            Example:
                >>> f = fields("f: [2D]")
                >>> f[0,0].get_shifted(1, 1)
                f_NE
            """
            return Field.Access(self.field,
                                tuple(a + b for a, b in zip(shift, self.offsets)),
                                self.index,
                                is_absolute_access=self.is_absolute_access,
                                dtype=self.dtype)

        def at_index(self, *idx_tuple) -> 'Field.Access':
            """Returns new Access with changed index.

            Example:
                >>> f = fields("f(9): [2D]")
                >>> f(0).at_index(8)
                f_C^8
            """
            return Field.Access(self.field, self.offsets, idx_tuple,
                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)

        def _eval_subs(self, old, new):
            # Substitute inside offsets and indices so symbolic/indirect accesses are updated too.
            return Field.Access(self.field,
                                tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
                                tuple(sp.sympify(a).subs(old, new) for a in self.index),
                                is_absolute_access=self.is_absolute_access,
                                dtype=self.dtype)

        @property
        def is_absolute_access(self) -> bool:
            """Indicates if a field access is relative to the loop counters (this is the default) or absolute"""
            return self._is_absolute_access

        @property
        def indirect_addressing_fields(self) -> Set['Field']:
            """Returns a set of fields that the access depends on.

             e.g. f[index_field[1, 0]], the outer access to f depends on index_field
             """
            return self._indirect_addressing_fields

        def _hashable_content(self):
            super_class_contents = super(Field.Access, self)._hashable_content()
            return (super_class_contents, self._field.hashable_contents(), *self._index, *self._offsets)

        def _staggered_offset(self, offsets, index):
            """Offsets minus half of the direction vector ``staggered_stencil[index]`` (used for printing)."""
            assert FieldType.is_staggered(self._field)
            neighbor = self._field.staggered_stencil[index]
            neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
            return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]

        def _latex(self, _):
            n = self._field.latex_name if self._field.latex_name else self._field.name
            offset_str = ",".join([sp.latex(o) for o in self.offsets])
            if FieldType.is_staggered(self._field):
                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
                                       for i in range(len(self.offsets))])
            if self.is_absolute_access:
                # Braces are doubled so that format() emits \mathbf{<offsets>} instead of \mathbf<offsets>.
                offset_str = "\\mathbf{{{}}}".format(offset_str)
            elif self.field.spatial_dimensions > 1:
                offset_str = "({})".format(offset_str)

            if FieldType.is_staggered(self._field):
                # For staggered fields the first index selects the staggered position, not a component.
                if self.index and self.field.index_dimensions > 1:
                    return "{{%s}_{%s}^{%s}}" % (n, offset_str, self.index[1:]
                                                 if len(self.index) > 2 else self.index[1])
                else:
                    return "{{%s}_{%s}}" % (n, offset_str)
            else:
                if self.index and self.field.index_dimensions > 0:
                    return "{{%s}_{%s}^{%s}}" % (n, offset_str, self.index if len(self.index) > 1 else self.index[0])
                else:
                    return "{{%s}_{%s}}" % (n, offset_str)

        def __str__(self):
            """Plain-text form such as ``f[0,1](2)``; absolute accesses are marked with ``[abs]``."""
            n = self._field.latex_name if self._field.latex_name else self._field.name
            offset_str = ",".join([sp.latex(o) for o in self.offsets])
            if FieldType.is_staggered(self._field):
                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
                                       for i in range(len(self.offsets))])
            if self.is_absolute_access:
                offset_str = "[abs]{}".format(offset_str)

            if FieldType.is_staggered(self._field):
                if self.index and self.field.index_dimensions > 1:
                    return "%s[%s](%s)" % (n, offset_str, self.index[1:] if len(self.index) > 2 else self.index[1])
                else:
                    return "%s[%s]" % (n, offset_str)
            else:
                if self.index and self.field.index_dimensions > 0:
                    return "%s[%s](%s)" % (n, offset_str, self.index if len(self.index) > 1 else self.index[0])
                else:
                    return "%s[%s]" % (n, offset_str)
876

Martin Bauer's avatar
Martin Bauer committed
877

Martin Bauer's avatar
Martin Bauer committed
878
879
def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None):
    """Derive a memory layout (coordinates ordered from slowest to fastest) from array strides.

    Args:
        strides: stride of every array dimension
        index_dimension_ids: positions in ``strides`` to ignore; the remaining dimensions are renumbered
            0, 1, ... in their original order before the layout is computed

    Returns:
        the renumbered coordinates sorted by decreasing stride (ties keep their original order),
        passed through ``normalize_layout``
    """
    skipped = set() if index_dimension_ids is None else set(index_dimension_ids)
    kept_strides = [stride for dim, stride in enumerate(strides) if dim not in skipped]
    # sorted() is stable even with reverse=True, so equal strides keep their relative order.
    slow_to_fast = sorted(range(len(kept_strides)), key=kept_strides.__getitem__, reverse=True)
    return normalize_layout(slow_to_fast)
884
885


Martin Bauer's avatar
Martin Bauer committed
886
887
888
889
890
891
def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None):
    """ Returns the memory layout (linearization order) of the numpy array, ordered from slow to fast.

    Examples:
        >>> get_layout_of_array(np.zeros([3,3,3]))
        (0, 1, 2)

    In this example the loop over the zeroth coordinate should be the outermost loop,
    followed by the first and second. Elements arr[x,y,0] and arr[x,y,1] are adjacent in memory.
    Normally constructed numpy arrays have this order, however by stride tricks or other frameworks, arrays
    with different memory layout can be created.

    The index_dimension_ids parameter lists array dimensions (e.g. index/component dimensions) that are
    excluded from the layout computation. The remaining dimensions are renumbered consecutively from 0
    in their original order, see ``get_layout_from_strides``.
    """
    # get_layout_from_strides already treats None as "no excluded dimensions".
    return get_layout_from_strides(arr.strides, index_dimension_ids)
902
903


Martin Bauer's avatar
Martin Bauer committed
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs):
    """Creates numpy array with given memory layout.

    Args:
        shape: shape of the resulting array
        layout: layout as tuple, where the coordinates are ordered from slow to fast; must be a permutation
                of ``range(len(shape))``
        alignment: number of bytes to align the beginning and the innermost coordinate to, or False for no
                   alignment; ``True`` selects 32 bytes
        byte_offset: only used when alignment is specified, align not beginning but address at this offset
                     mostly used to align first inner cell, not ghost cells
        **kwargs: forwarded to ``np.empty`` / ``aligned_empty`` (e.g. ``dtype``)

    Raises:
        ValueError: if ``layout`` is not a permutation of ``range(len(shape))``

    Example:
        >>> res = create_numpy_array_with_layout(shape=(2, 3, 4, 5), layout=(3, 2, 0, 1))
        >>> res.shape
        (2, 3, 4, 5)
        >>> get_layout_of_array(res)
        (3, 2, 0, 1)
    """
    # An exact permutation check: a set comparison would accept duplicates such as (0, 0, 1), which then
    # crashed with an IndexError below. Unlike `assert`, this check is not stripped under `python -O`.
    if sorted(layout) != list(range(len(shape))):
        raise ValueError("Wrong layout descriptor")

    # Record the transpositions that turn the identity order into `layout`.
    cur_layout = list(range(len(shape)))
    swaps = []
    for i in range(len(layout)):
        if cur_layout[i] != layout[i]:
            index_to_swap_with = cur_layout.index(layout[i])
            swaps.append((i, index_to_swap_with))
            cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i]
    assert tuple(cur_layout) == tuple(layout)

    # Allocate a C-contiguous buffer whose axes are already in layout order ...
    shape = list(shape)
    for a, b in swaps:
        shape[a], shape[b] = shape[b], shape[a]

    if not alignment:
        res = np.empty(shape, order='c', **kwargs)
    else:
        if alignment is True:
            alignment = 8 * 4
        res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs)

    # ... then undo the swaps on views, so the axes match `shape` while memory order follows `layout`.
    for a, b in reversed(swaps):
        res = res.swapaxes(a, b)
    return res


Martin Bauer's avatar
Martin Bauer committed
947
948
def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
    """Translate a layout descriptor into a layout tuple for ``dim`` spatial coordinates.

    The result lists coordinates from slowest to fastest (coordinate 0 is x). Descriptors are matched
    case-insensitively, like in ``layout_string_to_tuple``:

    * ``'fzyx'``, ``'zyxf'``, ``'soa'``, ``'aos'``: z, y, x order (x fastest), at most 3 dimensions.
      The index coordinate ``f`` is not part of a spatial layout, so SoA (fzyx) and AoS (zyxf) share
      the same spatial order.
    * ``'f'``, ``'reverse_numpy'``: reversed numpy order.
    * ``'c'``, ``'numpy'``: numpy (C) order.

    Raises:
        ValueError: for an unknown descriptor, or more than 3 dimensions with an xyz-style descriptor.
    """
    key = layout_str.lower()
    if key in ('fzyx', 'zyxf', 'soa', 'aos'):
        if dim > 3:
            raise ValueError("Layout descriptor %s supports at most 3 spatial dimensions, got %d"
                             % (layout_str, dim))
        return tuple(reversed(range(dim)))
    if key in ('f', 'reverse_numpy'):
        return tuple(reversed(range(dim)))
    if key in ('c', 'numpy'):
        return tuple(range(dim))
    raise ValueError("Unknown layout descriptor " + layout_str)
957
958


Martin Bauer's avatar
Martin Bauer committed
959
960
961
def layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
    """Translate a layout descriptor into a layout tuple for a field with ``dim`` coordinates in total.

    The result lists coordinates from slowest to fastest. For the xyz-style descriptors the last
    coordinate (``dim - 1``) is the index coordinate ``f``. Descriptors are matched case-insensitively:

    * ``'fzyx'`` / ``'soa'``: f slowest, x fastest -> ``(dim-1, ..., 1, 0)``; at most 4 coordinates.
    * ``'zyxf'`` / ``'aos'``: z, y, x, then f fastest -> ``(dim-2, ..., 0, dim-1)``; 1 to 4 coordinates.
    * ``'f'`` / ``'reverse_numpy'``: reversed numpy order.
    * ``'c'`` / ``'numpy'``: numpy (C) order.

    Raises:
        ValueError: for an unknown descriptor or an unsupported number of coordinates.
    """
    layout_str = layout_str.lower()
    if layout_str in ('fzyx', 'soa'):
        if dim > 4:
            raise ValueError("Layout descriptor %s supports at most 4 dimensions, got %d" % (layout_str, dim))
        return tuple(reversed(range(dim)))
    if layout_str in ('zyxf', 'aos'):
        # dim == 0 would otherwise silently produce the invalid layout (-1,).
        if not 1 <= dim <= 4:
            raise ValueError("Layout descriptor %s requires between 1 and 4 dimensions, got %d"
                             % (layout_str, dim))
        return tuple(reversed(range(dim - 1))) + (dim - 1,)
    if layout_str in ('f', 'reverse_numpy'):
        return tuple(reversed(range(dim)))
    if layout_str in ('c', 'numpy'):
        return tuple(range(dim))
    raise ValueError("Unknown layout descriptor " + layout_str)
972
973


Martin Bauer's avatar
Martin Bauer committed
974
def normalize_layout(layout):