From e5f4c74d3e7ae66bcb914ba1523bd0c43e530446 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 13 Jan 2020 11:50:55 +0100
Subject: [PATCH] Make add_arrays return fields to arrays

---
 pystencils/datahandling/datahandling_interface.py | 15 +++++++++++----
 pystencils_tests/test_datahandling.py             |  4 +++-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/pystencils/datahandling/datahandling_interface.py b/pystencils/datahandling/datahandling_interface.py
index 2cc71b2b1..e055e9ea8 100644
--- a/pystencils/datahandling/datahandling_interface.py
+++ b/pystencils/datahandling/datahandling_interface.py
@@ -68,20 +68,27 @@ class DataHandling(ABC):
 
         >>> from pystencils.datahandling import create_data_handling
         >>> dh = create_data_handling((20, 30))
-        >>> dh.add_arrays('x, y(9)')
+        >>> x, y =dh.add_arrays('x, y(9)')
         >>> print(dh.fields)
-        {'x': x: double[20,30], 'y': y(9): double[20,30]}
-        >>> assert dh.fields['x'].shape = (20, 30)
-        >>> assert dh.fields['y'].index_shape = (9,)
+        {'x': x: double[22,32], 'y': y(9): double[22,32]}
+        >>> assert x == dh.fields['x']
+        >>> assert dh.fields['x'].shape == (22, 32)
+        >>> assert dh.fields['y'].index_shape == (9,)
 
         Args:
             description (str): String description of the fields to add
+        Returns:
+            Fields representing the just created arrays
         """
         from pystencils.field import _parse_part1
 
+        names = []
         for name, indices in _parse_part1(description):
+            names.append(name)
             self.add_array(name, values_per_cell=indices)
 
+        return (self.fields[n] for n in names)
+
     @abstractmethod
     def has_data(self, name):
         """Returns true if a field or custom data element with this name was added."""
diff --git a/pystencils_tests/test_datahandling.py b/pystencils_tests/test_datahandling.py
index 2f4d87a40..1b79fdbec 100644
--- a/pystencils_tests/test_datahandling.py
+++ b/pystencils_tests/test_datahandling.py
@@ -233,9 +233,11 @@ def test_add_arrays():
     field_description = 'x, y(9)'
 
     dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy')
-    dh.add_arrays(field_description)
+    x_, y_ = dh.add_arrays(field_description)
 
     x, y = ps.fields(field_description + ': [3,4,5]')
 
+    assert x_ == x
+    assert y_ == y
     assert x == dh.fields['x']
     assert y == dh.fields['y']
-- 
GitLab