diff --git a/pystencils/fd/derivative.py b/pystencils/fd/derivative.py index 7acd245059615ded27e2c1fe023e344b48c34a6e..814501bf0fd67f594055f5bf25fdf192fe6d9fec 100644 --- a/pystencils/fd/derivative.py +++ b/pystencils/fd/derivative.py @@ -306,7 +306,8 @@ def expand_diff_full(expr, functions=None, constants=None): functions.difference_update(constants) def visit(e): - e = e.expand() + if not isinstance(e, sp.Tuple): + e = e.expand() if e.func == Diff: result = 0 @@ -331,6 +332,9 @@ def expand_diff_full(expr, functions=None, constants=None): return result elif isinstance(e, sp.Piecewise): return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args)) + elif isinstance(expr, sp.Tuple): + new_args = [visit(arg) for arg in e.args] + return sp.Tuple(*new_args) else: new_args = [visit(arg) for arg in e.args] return e.func(*new_args) if new_args else e @@ -370,6 +374,9 @@ def expand_diff_linear(expr, functions=None, constants=None): return diff.split_linear(functions) elif isinstance(expr, sp.Piecewise): return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args)) + elif isinstance(expr, sp.Tuple): + new_args = [expand_diff_linear(e, functions) for e in expr.args] + return sp.Tuple(*new_args) else: new_args = [expand_diff_linear(e, functions) for e in expr.args] result = sp.expand(expr.func(*new_args) if new_args else expr) diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index a2813e34f9c72c8bbcd848d21993c181cf31fc43..253ca9f26547d78e6b23b2c6dbc582196b49960d 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -44,7 +44,7 @@ def test_error_handling(): Field.create_generic('f', spatial_dimensions=2, index_dimensions=1, dtype=struct_dtype) assert 'index dimension' in str(e.value) - arr = np.array([[1, 2.0, 3], [1, 2.0, 3]], dtype=struct_dtype) + arr = np.array([[[(1,)*3, (2,)*3, (3,)*3]]*2], dtype=struct_dtype) Field.create_from_numpy_array('f', arr, index_dimensions=0) with pytest.raises(ValueError) as e: Field.create_from_numpy_array('f', arr, index_dimensions=1) diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py index d32e347f67cdd8f265d856c3eafa2ff6f3cc575d..477765bb31289dbfdc48927e3cc55f10d49f16a0 100644 --- a/pystencils_tests/test_interpolation.py +++ b/pystencils_tests/test_interpolation.py @@ -153,7 +153,8 @@ def test_rotate_interpolation_gpu(dtype, address_mode, use_textures): @pytest.mark.parametrize('dtype', [np.float64, np.float32, np.int32]) @pytest.mark.parametrize('use_textures', ('use_textures', False,)) def test_shift_interpolation_gpu(address_mode, dtype, use_textures): - if int(sympy.__version__.replace('.', '')) < 12 and address_mode in ['mirror', 'warp']: + sver = sympy.__version__.split(".") + if (int(sver[0]) == 1 and int(sver[1]) < 2) and address_mode in ['mirror', 'warp']: pytest.skip() pytest.importorskip('pycuda')