From a817ea2c57550da708e18717bb130430b478c585 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 7 Oct 2020 11:30:32 +0200
Subject: [PATCH] Fixes to shift_slice

- `shift_slice` did not work for single slices as it always tried to
iterate its argument
- `shift_slice` returned lists of slices, but for accessing numpy
arrays, tuples of slices are required
---
 pystencils/slicing.py            |  7 +++++--
 pystencils_tests/test_slicing.py | 23 +++++++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/pystencils/slicing.py b/pystencils/slicing.py
index 2f44a00a9..64b9d308f 100644
--- a/pystencils/slicing.py
+++ b/pystencils/slicing.py
@@ -89,9 +89,12 @@ def shift_slice(slices, offset):
             raise ValueError()
 
     if hasattr(offset, '__len__'):
-        return [shift_slice_component(k, off) for k, off in zip(slices, offset)]
+        return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset))
     else:
-        return [shift_slice_component(k, offset) for k in slices]
+        if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float):
+            return shift_slice_component(slices, offset)
+        else:
+            return tuple(shift_slice_component(k, offset) for k in slices)
 
 
 def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
diff --git a/pystencils_tests/test_slicing.py b/pystencils_tests/test_slicing.py
index 79e36576b..e2b9591a1 100644
--- a/pystencils_tests/test_slicing.py
+++ b/pystencils_tests/test_slicing.py
@@ -1,4 +1,5 @@
 import numpy as np
+from numpy.testing import assert_array_equal
 from pystencils import create_data_handling
 from pystencils.slicing import SlicedGetter, make_slice, SlicedGetterDataHandling, shift_slice, slice_intersection
 
@@ -59,6 +60,28 @@ def test_shift_slice():
     assert sh[1] == 1.5
 
 
+def test_shifted_array_access():
+    arr = np.array(range(10))
+    
+    sh = make_slice[2:5]
+    assert_array_equal(arr[sh], [2,3,4])
+
+    sh = shift_slice(sh, 3)
+    assert_array_equal(arr[sh], [5,6,7])
+
+    arr = np.array([
+        [1, 2, 3],
+        [4, 5, 6],
+        [7, 8, 9]
+    ])
+
+    sh = make_slice[0:2, 0:2]
+    assert_array_equal(arr[sh], [[1, 2], [4, 5]])
+
+    sh = shift_slice(sh, (1,1))
+    assert_array_equal(arr[sh], [[5, 6], [8, 9]])
+
+
 def test_slice_intersection():
     sl1 = make_slice[1:10, 1:10]
     sl2 = make_slice[5:15, 5:15]
-- 
GitLab