From 16d731b5d5d15e3e43641c3c40a6e8fd5ea63974 Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Thu, 23 Jan 2020 15:17:15 +0100 Subject: [PATCH] more test improvements --- pystencils/fd/derivative.py | 9 ++++++++- pystencils_tests/test_field.py | 2 +- pystencils_tests/test_interpolation.py | 3 ++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pystencils/fd/derivative.py b/pystencils/fd/derivative.py index 7acd24505..814501bf0 100644 --- a/pystencils/fd/derivative.py +++ b/pystencils/fd/derivative.py @@ -306,7 +306,8 @@ def expand_diff_full(expr, functions=None, constants=None): functions.difference_update(constants) def visit(e): - e = e.expand() + if not isinstance(e, sp.Tuple): + e = e.expand() if e.func == Diff: result = 0 @@ -331,6 +332,9 @@ def expand_diff_full(expr, functions=None, constants=None): return result elif isinstance(e, sp.Piecewise): return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args)) + elif isinstance(expr, sp.Tuple): + new_args = [visit(arg) for arg in e.args] + return sp.Tuple(*new_args) else: new_args = [visit(arg) for arg in e.args] return e.func(*new_args) if new_args else e @@ -370,6 +374,9 @@ def expand_diff_linear(expr, functions=None, constants=None): return diff.split_linear(functions) elif isinstance(expr, sp.Piecewise): return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args)) + elif isinstance(expr, sp.Tuple): + new_args = [expand_diff_linear(e, functions) for e in expr.args] + return sp.Tuple(*new_args) else: new_args = [expand_diff_linear(e, functions) for e in expr.args] result = sp.expand(expr.func(*new_args) if new_args else expr) diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index a2813e34f..253ca9f26 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -44,7 +44,7 @@ def test_error_handling(): Field.create_generic('f', spatial_dimensions=2, index_dimensions=1, dtype=struct_dtype) assert 'index dimension' in str(e.value) - arr = np.array([[1, 2.0, 3], [1, 2.0, 3]], dtype=struct_dtype) + arr = np.array([[[(1,)*3, (2,)*3, (3,)*3]]*2], dtype=struct_dtype) Field.create_from_numpy_array('f', arr, index_dimensions=0) with pytest.raises(ValueError) as e: Field.create_from_numpy_array('f', arr, index_dimensions=1) diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py index d32e347f6..477765bb3 100644 --- a/pystencils_tests/test_interpolation.py +++ b/pystencils_tests/test_interpolation.py @@ -153,7 +153,8 @@ def test_rotate_interpolation_gpu(dtype, address_mode, use_textures): @pytest.mark.parametrize('dtype', [np.float64, np.float32, np.int32]) @pytest.mark.parametrize('use_textures', ('use_textures', False,)) def test_shift_interpolation_gpu(address_mode, dtype, use_textures): - if int(sympy.__version__.replace('.', '')) < 12 and address_mode in ['mirror', 'warp']: + sver = sympy.__version__.split(".") + if (int(sver[0]) == 1 and int(sver[1]) < 2) and address_mode in ['mirror', 'warp']: pytest.skip() pytest.importorskip('pycuda') -- GitLab