Skip to content
Snippets Groups Projects
Commit 16d731b5 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

more test improvements

parent 54ff216b
No related merge requests found
......@@ -306,7 +306,8 @@ def expand_diff_full(expr, functions=None, constants=None):
functions.difference_update(constants)
def visit(e):
e = e.expand()
if not isinstance(e, sp.Tuple):
e = e.expand()
if e.func == Diff:
result = 0
......@@ -331,6 +332,9 @@ def expand_diff_full(expr, functions=None, constants=None):
return result
elif isinstance(e, sp.Piecewise):
return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args))
elif isinstance(expr, sp.Tuple):
new_args = [visit(arg) for arg in e.args]
return sp.Tuple(*new_args)
else:
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
......@@ -370,6 +374,9 @@ def expand_diff_linear(expr, functions=None, constants=None):
return diff.split_linear(functions)
elif isinstance(expr, sp.Piecewise):
return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args))
elif isinstance(expr, sp.Tuple):
new_args = [expand_diff_linear(e, functions) for e in expr.args]
return sp.Tuple(*new_args)
else:
new_args = [expand_diff_linear(e, functions) for e in expr.args]
result = sp.expand(expr.func(*new_args) if new_args else expr)
......
......@@ -44,7 +44,7 @@ def test_error_handling():
Field.create_generic('f', spatial_dimensions=2, index_dimensions=1, dtype=struct_dtype)
assert 'index dimension' in str(e.value)
arr = np.array([[1, 2.0, 3], [1, 2.0, 3]], dtype=struct_dtype)
arr = np.array([[[(1,)*3, (2,)*3, (3,)*3]]*2], dtype=struct_dtype)
Field.create_from_numpy_array('f', arr, index_dimensions=0)
with pytest.raises(ValueError) as e:
Field.create_from_numpy_array('f', arr, index_dimensions=1)
......
......@@ -153,7 +153,8 @@ def test_rotate_interpolation_gpu(dtype, address_mode, use_textures):
@pytest.mark.parametrize('dtype', [np.float64, np.float32, np.int32])
@pytest.mark.parametrize('use_textures', ('use_textures', False,))
def test_shift_interpolation_gpu(address_mode, dtype, use_textures):
if int(sympy.__version__.replace('.', '')) < 12 and address_mode in ['mirror', 'warp']:
sver = sympy.__version__.split(".")
if (int(sver[0]) == 1 and int(sver[1]) < 2) and address_mode in ['mirror', 'warp']:
pytest.skip()
pytest.importorskip('pycuda')
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment