From 7c97bb74d0ca3e71a018a368356ecafdffffd799 Mon Sep 17 00:00:00 2001
From: markus holzer <markus.holzer@fau.de>
Date: Mon, 10 Aug 2020 13:32:43 +0200
Subject: [PATCH] Added test cases for fd derivation

---
 pystencils/fd/derivative.py            |  4 ++-
 pystencils_tests/test_fd_derivation.py | 34 --------------------------
 pystencils_tests/test_fd_derivative.py | 24 ++++++++++++++++++
 pystencils_tests/test_utils.py         | 33 +++++++++++++++++++++++--
 4 files changed, 58 insertions(+), 37 deletions(-)
 delete mode 100644 pystencils_tests/test_fd_derivation.py
 create mode 100644 pystencils_tests/test_fd_derivative.py

diff --git a/pystencils/fd/derivative.py b/pystencils/fd/derivative.py
index 0e2890ec5..c119d1e2e 100644
--- a/pystencils/fd/derivative.py
+++ b/pystencils/fd/derivative.py
@@ -228,7 +228,9 @@ def diff_terms(expr):
 
     Example:
         >>> x, y = sp.symbols("x, y")
-        >>> diff_terms( diff(x, 0, 0)  )
+        >>> diff_terms( diff(x, 0, 0) )
+        {Diff(Diff(x, 0, -1), 0, -1)}
+        >>> diff_terms( diff(x, 0, 0) + y )
         {Diff(Diff(x, 0, -1), 0, -1)}
     """
     result = set()
diff --git a/pystencils_tests/test_fd_derivation.py b/pystencils_tests/test_fd_derivation.py
deleted file mode 100644
index c2bb1aa08..000000000
--- a/pystencils_tests/test_fd_derivation.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import pytest
-import sympy as sp
-
-from pystencils.utils import LinearEquationSystem
-
-
-def test_linear_equation_system():
-    unknowns = sp.symbols("x_:3")
-    x, y, z = unknowns
-    m = LinearEquationSystem(unknowns)
-    m.add_equation(x + y - 2)
-    m.add_equation(x - y - 1)
-    assert m.solution_structure() == 'multiple'
-    m.set_unknown_zero(2)
-    assert m.solution_structure() == 'single'
-    solution = m.solution()
-    assert solution[unknowns[2]] == 0
-    assert solution[unknowns[1]] == sp.Rational(1, 2)
-    assert solution[unknowns[0]] == sp.Rational(3, 2)
-
-    m.set_unknown_zero(0)
-    assert m.solution_structure() == 'none'
-
-    # special case where less rows than unknowns, but no solution
-    m = LinearEquationSystem(unknowns)
-    m.add_equation(x - 3)
-    m.add_equation(x - 4)
-    assert m.solution_structure() == 'none'
-    m.add_equation(y - 4)
-    assert m.solution_structure() == 'none'
-
-    with pytest.raises(ValueError) as e:
-        m.add_equation(x**2 - 1)
-    assert 'Not a linear equation' in str(e.value)
diff --git a/pystencils_tests/test_fd_derivative.py b/pystencils_tests/test_fd_derivative.py
new file mode 100644
index 000000000..fc2fca655
--- /dev/null
+++ b/pystencils_tests/test_fd_derivative.py
@@ -0,0 +1,24 @@
+import sympy as sp
+from pystencils import fields
+from pystencils.fd import Diff, diff, collect_diffs
+from pystencils.fd.derivative import replace_generic_laplacian
+
+
+def test_fs():
+    f = sp.Symbol("f", commutative=False)
+
+    a = Diff(Diff(Diff(f, 1), 0), 0)
+    assert a.is_commutative is False
+    print(str(a))
+
+    assert diff(f) == f
+
+    x, y = sp.symbols("x, y")
+    collected_terms = collect_diffs(diff(x, 0, 0))
+    assert collected_terms == Diff(Diff(x, 0, -1), 0, -1)
+
+    src = fields("src : double[2D]")
+    expr = sp.Add(Diff(Diff(src[0, 0])), 10)
+    expected = Diff(Diff(src[0, 0], 0, -1), 0, -1) + Diff(Diff(src[0, 0], 1, -1), 1, -1) + 10
+    result = replace_generic_laplacian(expr, 3)
+    assert result == expected
\ No newline at end of file
diff --git a/pystencils_tests/test_utils.py b/pystencils_tests/test_utils.py
index 231b165a9..3085ef61a 100644
--- a/pystencils_tests/test_utils.py
+++ b/pystencils_tests/test_utils.py
@@ -1,9 +1,38 @@
+import pytest
 import sympy as sp
 from pystencils.utils import LinearEquationSystem
 from pystencils.utils import DotDict
 
 
-def test_LinearEquationSystem():
+def test_linear_equation_system():
+    unknowns = sp.symbols("x_:3")
+    x, y, z = unknowns
+    m = LinearEquationSystem(unknowns)
+    m.add_equation(x + y - 2)
+    m.add_equation(x - y - 1)
+    assert m.solution_structure() == 'multiple'
+    m.set_unknown_zero(2)
+    assert m.solution_structure() == 'single'
+    solution = m.solution()
+    assert solution[unknowns[2]] == 0
+    assert solution[unknowns[1]] == sp.Rational(1, 2)
+    assert solution[unknowns[0]] == sp.Rational(3, 2)
+
+    m.set_unknown_zero(0)
+    assert m.solution_structure() == 'none'
+
+    # special case where less rows than unknowns, but no solution
+    m = LinearEquationSystem(unknowns)
+    m.add_equation(x - 3)
+    m.add_equation(x - 4)
+    assert m.solution_structure() == 'none'
+    m.add_equation(y - 4)
+    assert m.solution_structure() == 'none'
+
+    with pytest.raises(ValueError) as e:
+        m.add_equation(x**2 - 1)
+    assert 'Not a linear equation' in str(e.value)
+
     x, y, z = sp.symbols("x, y, z")
     les = LinearEquationSystem([x, y, z])
     les.add_equation(1 * x + 2 * y - 1 * z + 4)
@@ -37,7 +66,7 @@ def test_LinearEquationSystem():
     assert les.solution_structure() == 'none'
 
 
-def test_DotDict():
+def test_dot_dict():
     d = {'a': {'c': 7}, 'b': 6}
     t = DotDict(d)
     assert t.a.c == 7
-- 
GitLab