diff --git a/pystencils/fd/derivative.py b/pystencils/fd/derivative.py index 0e2890ec558db92a364c6912b19b26cb676f7217..c119d1e2ec34c32c67f18f7837c43dee05cfc65b 100644 --- a/pystencils/fd/derivative.py +++ b/pystencils/fd/derivative.py @@ -228,7 +228,9 @@ def diff_terms(expr): Example: >>> x, y = sp.symbols("x, y") - >>> diff_terms( diff(x, 0, 0) ) + >>> diff_terms( diff(x, 0, 0) ) + {Diff(Diff(x, 0, -1), 0, -1)} + >>> diff_terms( diff(x, 0, 0) + y ) {Diff(Diff(x, 0, -1), 0, -1)} """ result = set() diff --git a/pystencils_tests/test_fd_derivation.py b/pystencils_tests/test_fd_derivation.py deleted file mode 100644 index c2bb1aa08e7007e024aaf0867c15764c5e0880ce..0000000000000000000000000000000000000000 --- a/pystencils_tests/test_fd_derivation.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -import sympy as sp - -from pystencils.utils import LinearEquationSystem - - -def test_linear_equation_system(): - unknowns = sp.symbols("x_:3") - x, y, z = unknowns - m = LinearEquationSystem(unknowns) - m.add_equation(x + y - 2) - m.add_equation(x - y - 1) - assert m.solution_structure() == 'multiple' - m.set_unknown_zero(2) - assert m.solution_structure() == 'single' - solution = m.solution() - assert solution[unknowns[2]] == 0 - assert solution[unknowns[1]] == sp.Rational(1, 2) - assert solution[unknowns[0]] == sp.Rational(3, 2) - - m.set_unknown_zero(0) - assert m.solution_structure() == 'none' - - # special case where less rows than unknowns, but no solution - m = LinearEquationSystem(unknowns) - m.add_equation(x - 3) - m.add_equation(x - 4) - assert m.solution_structure() == 'none' - m.add_equation(y - 4) - assert m.solution_structure() == 'none' - - with pytest.raises(ValueError) as e: - m.add_equation(x**2 - 1) - assert 'Not a linear equation' in str(e.value) diff --git a/pystencils_tests/test_fd_derivative.py b/pystencils_tests/test_fd_derivative.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2fca655440574659f001048f340a1c008028f5 --- /dev/null +++ b/pystencils_tests/test_fd_derivative.py @@ -0,0 +1,24 @@ +import sympy as sp +from pystencils import fields +from pystencils.fd import Diff, diff, collect_diffs +from pystencils.fd.derivative import replace_generic_laplacian + + +def test_fs(): + f = sp.Symbol("f", commutative=False) + + a = Diff(Diff(Diff(f, 1), 0), 0) + assert a.is_commutative is False + print(str(a)) + + assert diff(f) == f + + x, y = sp.symbols("x, y") + collected_terms = collect_diffs(diff(x, 0, 0)) + assert collected_terms == Diff(Diff(x, 0, -1), 0, -1) + + src = fields("src : double[2D]") + expr = sp.Add(Diff(Diff(src[0, 0])), 10) + expected = Diff(Diff(src[0, 0], 0, -1), 0, -1) + Diff(Diff(src[0, 0], 1, -1), 1, -1) + 10 + result = replace_generic_laplacian(expr, 3) + assert result == expected \ No newline at end of file diff --git a/pystencils_tests/test_utils.py b/pystencils_tests/test_utils.py index 231b165a92134c1ab4d1c2904bb71f066ced196f..3085ef61a3a33f599a53d5c7233d892a59367518 100644 --- a/pystencils_tests/test_utils.py +++ b/pystencils_tests/test_utils.py @@ -1,9 +1,38 @@ +import pytest import sympy as sp from pystencils.utils import LinearEquationSystem from pystencils.utils import DotDict -def test_LinearEquationSystem(): +def test_linear_equation_system(): + unknowns = sp.symbols("x_:3") + x, y, z = unknowns + m = LinearEquationSystem(unknowns) + m.add_equation(x + y - 2) + m.add_equation(x - y - 1) + assert m.solution_structure() == 'multiple' + m.set_unknown_zero(2) + assert m.solution_structure() == 'single' + solution = m.solution() + assert solution[unknowns[2]] == 0 + assert solution[unknowns[1]] == sp.Rational(1, 2) + assert solution[unknowns[0]] == sp.Rational(3, 2) + + m.set_unknown_zero(0) + assert m.solution_structure() == 'none' + + # special case where less rows than unknowns, but no solution + m = LinearEquationSystem(unknowns) + m.add_equation(x - 3) + m.add_equation(x - 4) + assert m.solution_structure() == 'none' + m.add_equation(y - 4) + assert m.solution_structure() == 'none' + + with pytest.raises(ValueError) as e: + m.add_equation(x**2 - 1) + assert 'Not a linear equation' in str(e.value) + x, y, z = sp.symbols("x, y, z") les = LinearEquationSystem([x, y, z]) les.add_equation(1 * x + 2 * y - 1 * z + 4) @@ -37,7 +66,7 @@ def test_LinearEquationSystem(): assert les.solution_structure() == 'none' -def test_DotDict(): +def test_dot_dict(): d = {'a': {'c': 7}, 'b': 6} t = DotDict(d) assert t.a.c == 7