From 16d731b5d5d15e3e43641c3c40a6e8fd5ea63974 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Thu, 23 Jan 2020 15:17:15 +0100
Subject: [PATCH] more test improvements

---
 pystencils/fd/derivative.py            | 9 ++++++++-
 pystencils_tests/test_field.py         | 2 +-
 pystencils_tests/test_interpolation.py | 3 ++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/pystencils/fd/derivative.py b/pystencils/fd/derivative.py
index 7acd24505..814501bf0 100644
--- a/pystencils/fd/derivative.py
+++ b/pystencils/fd/derivative.py
@@ -306,7 +306,8 @@ def expand_diff_full(expr, functions=None, constants=None):
             functions.difference_update(constants)
 
     def visit(e):
-        e = e.expand()
+        if not isinstance(e, sp.Tuple):
+            e = e.expand()
 
         if e.func == Diff:
             result = 0
@@ -331,6 +332,9 @@ def expand_diff_full(expr, functions=None, constants=None):
             return result
         elif isinstance(e, sp.Piecewise):
             return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args))
+        elif isinstance(expr, sp.Tuple):
+            new_args = [visit(arg) for arg in e.args]
+            return sp.Tuple(*new_args)
         else:
             new_args = [visit(arg) for arg in e.args]
             return e.func(*new_args) if new_args else e
@@ -370,6 +374,9 @@ def expand_diff_linear(expr, functions=None, constants=None):
                 return diff.split_linear(functions)
     elif isinstance(expr, sp.Piecewise):
         return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args))
+    elif isinstance(expr, sp.Tuple):
+        new_args = [expand_diff_linear(e, functions) for e in expr.args]
+        return sp.Tuple(*new_args)
     else:
         new_args = [expand_diff_linear(e, functions) for e in expr.args]
         result = sp.expand(expr.func(*new_args) if new_args else expr)
diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py
index a2813e34f..253ca9f26 100644
--- a/pystencils_tests/test_field.py
+++ b/pystencils_tests/test_field.py
@@ -44,7 +44,7 @@ def test_error_handling():
         Field.create_generic('f', spatial_dimensions=2, index_dimensions=1, dtype=struct_dtype)
     assert 'index dimension' in str(e.value)
 
-    arr = np.array([[1, 2.0, 3], [1, 2.0, 3]], dtype=struct_dtype)
+    arr = np.array([[[(1,)*3, (2,)*3, (3,)*3]]*2], dtype=struct_dtype)
     Field.create_from_numpy_array('f', arr, index_dimensions=0)
     with pytest.raises(ValueError) as e:
         Field.create_from_numpy_array('f', arr, index_dimensions=1)
diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py
index d32e347f6..477765bb3 100644
--- a/pystencils_tests/test_interpolation.py
+++ b/pystencils_tests/test_interpolation.py
@@ -153,7 +153,8 @@ def test_rotate_interpolation_gpu(dtype, address_mode, use_textures):
 @pytest.mark.parametrize('dtype', [np.float64, np.float32, np.int32])
 @pytest.mark.parametrize('use_textures', ('use_textures', False,))
 def test_shift_interpolation_gpu(address_mode, dtype, use_textures):
-    if int(sympy.__version__.replace('.', '')) < 12 and address_mode in ['mirror', 'warp']:
+    sver = sympy.__version__.split(".")
+    if (int(sver[0]) == 1 and int(sver[1]) < 2) and address_mode in ['mirror', 'warp']:
         pytest.skip()
     pytest.importorskip('pycuda')
 
-- 
GitLab