From a817ea2c57550da708e18717bb130430b478c585 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 7 Oct 2020 11:30:32 +0200 Subject: [PATCH] Fixes to shift_slice - `shift_slice` did not work for single slices as it always tried to iterate its argument - `shift_slice` returned lists of slices, but for accessing numpy arrays, tuples of slices are required --- pystencils/slicing.py | 7 +++++-- pystencils_tests/test_slicing.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/pystencils/slicing.py b/pystencils/slicing.py index 2f44a00a9..64b9d308f 100644 --- a/pystencils/slicing.py +++ b/pystencils/slicing.py @@ -89,9 +89,12 @@ def shift_slice(slices, offset): raise ValueError() if hasattr(offset, '__len__'): - return [shift_slice_component(k, off) for k, off in zip(slices, offset)] + return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset)) else: - return [shift_slice_component(k, offset) for k in slices] + if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float): + return shift_slice_component(slices, offset) + else: + return tuple(shift_slice_component(k, offset) for k in slices) def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0): diff --git a/pystencils_tests/test_slicing.py b/pystencils_tests/test_slicing.py index 79e36576b..e2b9591a1 100644 --- a/pystencils_tests/test_slicing.py +++ b/pystencils_tests/test_slicing.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.testing import assert_array_equal from pystencils import create_data_handling from pystencils.slicing import SlicedGetter, make_slice, SlicedGetterDataHandling, shift_slice, slice_intersection @@ -59,6 +60,28 @@ def test_shift_slice(): assert sh[1] == 1.5 +def test_shifted_array_access(): + arr = np.array(range(10)) + + sh = make_slice[2:5] + assert_array_equal(arr[sh], [2,3,4]) + + sh = shift_slice(sh, 3) + assert_array_equal(arr[sh], [5,6,7]) + + arr = np.array([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ]) + + sh = make_slice[0:2, 0:2] + assert_array_equal(arr[sh], [[1, 2], [4, 5]]) + + sh = shift_slice(sh, (1,1)) + assert_array_equal(arr[sh], [[5, 6], [8, 9]]) + + def test_slice_intersection(): sl1 = make_slice[1:10, 1:10] sl2 = make_slice[5:15, 5:15] -- GitLab