From 1d4c3bf619226767f823e0e72db5bb5739e8fba3 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 28 Nov 2021 15:32:05 +0100
Subject: [PATCH] Fixed test for sliced iteration with buffer to use dynamic
 field sizes

---
 pystencils_tests/test_buffer.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/pystencils_tests/test_buffer.py b/pystencils_tests/test_buffer.py
index 28665a294..b8af6f53f 100644
--- a/pystencils_tests/test_buffer.py
+++ b/pystencils_tests/test_buffer.py
@@ -20,7 +20,7 @@ def _generate_fields(dt=np.uint64, num_directions=1, layout='numpy'):
     fields = []
     for size in field_sizes:
         field_layout = layout_string_to_tuple(layout, len(size))
-        src_arr = create_numpy_array_with_layout(size, field_layout)
+        src_arr = create_numpy_array_with_layout(size, field_layout, dtype=dt)
 
         array_data = np.reshape(np.arange(1, int(np.prod(size)+1)), size)
         # Use flat iterator to input data into the array
@@ -193,10 +193,14 @@ def test_field_layouts():
 
 def test_iteration_slices():
     num_cell_values = 19
-    fields = _generate_fields(num_directions=num_cell_values)
+    dt = np.uint64
+    fields = _generate_fields(dt=dt, num_directions=num_cell_values)
     for (src_arr, dst_arr, bufferArr) in fields:
-        src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
-        dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
+        spatial_dimensions = len(src_arr.shape) - 1
+        # src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
+        # dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
+        src_field = Field.create_generic("src_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
+        dst_field = Field.create_generic("dst_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
         buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1,
                                         field_type=FieldType.BUFFER, dtype=src_arr.dtype)
 
@@ -214,7 +218,7 @@ def test_iteration_slices():
 
         #   Fill the entire array with data
         src_arr[(slice(None, None, 1),) * dim] = np.arange(num_cell_values)
-        dst_arr.fill(0.0)
+        dst_arr.fill(0)
 
         pack_code = create_kernel(pack_eqs, iteration_slice=pack_slice, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype})
         pack_kernel = pack_code.compile()
@@ -232,6 +236,6 @@ def test_iteration_slices():
 
         #   Check if only every second entry of the leftmost slice has been copied
         np.testing.assert_equal(dst_arr[pack_slice], src_arr[pack_slice])
-        np.testing.assert_equal(dst_arr[(slice(1, None, 2),)  * (dim-1) + (0,)], 0.0)
-        np.testing.assert_equal(dst_arr[(slice(None, None, 1),)  * (dim-1) + (slice(1,None),)], 0.0)
+        np.testing.assert_equal(dst_arr[(slice(1, None, 2),)  * (dim-1) + (0,)], 0)
+        np.testing.assert_equal(dst_arr[(slice(None, None, 1),)  * (dim-1) + (slice(1,None),)], 0)
 
-- 
GitLab