diff --git a/pystencils/slicing.py b/pystencils/slicing.py index 2f44a00a91ba42d4fc6834a20fb21023080e4653..64b9d308f21dfbea63d4096323fa014b725dc572 100644 --- a/pystencils/slicing.py +++ b/pystencils/slicing.py @@ -89,9 +89,12 @@ def shift_slice(slices, offset): raise ValueError() if hasattr(offset, '__len__'): - return [shift_slice_component(k, off) for k, off in zip(slices, offset)] + return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset)) else: - return [shift_slice_component(k, offset) for k in slices] + if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float): + return shift_slice_component(slices, offset) + else: + return tuple(shift_slice_component(k, offset) for k in slices) def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0): diff --git a/pystencils_tests/test_slicing.py b/pystencils_tests/test_slicing.py index 79e36576bfb622cae3f4f9ed865a8b2f8308430b..e2b9591a137d17ddf54215a80d666c9f5ecd619b 100644 --- a/pystencils_tests/test_slicing.py +++ b/pystencils_tests/test_slicing.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.testing import assert_array_equal from pystencils import create_data_handling from pystencils.slicing import SlicedGetter, make_slice, SlicedGetterDataHandling, shift_slice, slice_intersection @@ -59,6 +60,28 @@ def test_shift_slice(): assert sh[1] == 1.5 +def test_shifted_array_access(): + arr = np.array(range(10)) + + sh = make_slice[2:5] + assert_array_equal(arr[sh], [2,3,4]) + + sh = shift_slice(sh, 3) + assert_array_equal(arr[sh], [5,6,7]) + + arr = np.array([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ]) + + sh = make_slice[0:2, 0:2] + assert_array_equal(arr[sh], [[1, 2], [4, 5]]) + + sh = shift_slice(sh, (1,1)) + assert_array_equal(arr[sh], [[5, 6], [8, 9]]) + + def test_slice_intersection(): sl1 = make_slice[1:10, 1:10] sl2 = make_slice[5:15, 5:15]