Commit 61d1bae6 authored by Martin Bauer's avatar Martin Bauer
Browse files

Tests and documentation for derivative module

parent beec6d3e
...@@ -130,4 +130,4 @@ pages: ...@@ -130,4 +130,4 @@ pages:
tags: tags:
- docker - docker
only: only:
- master@software/pystencils - master@pycodegen/pystencils
This diff is collapsed.
...@@ -15,6 +15,6 @@ It is a good idea to download them and run them directly to be able to play arou ...@@ -15,6 +15,6 @@ It is a good idea to download them and run them directly to be able to play arou
/notebooks/05_tutorial_phasefield_spinodal_decomposition.ipynb /notebooks/05_tutorial_phasefield_spinodal_decomposition.ipynb
/notebooks/06_tutorial_phasefield_dentritic_growth.ipynb /notebooks/06_tutorial_phasefield_dentritic_growth.ipynb
/notebooks/demo_assignment_collection.ipynb /notebooks/demo_assignment_collection.ipynb
/notebooks/demo_derivatives.ipynb
/notebooks/demo_benchmark.ipynb /notebooks/demo_benchmark.ipynb
/notebooks/demo_wave_equation.ipynb /notebooks/demo_wave_equation.ipynb
import sympy as sp import sympy as sp
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from pystencils import Field from pystencils import Field
from pystencils.sympyextensions import normalize_product, prod from pystencils.sympyextensions import normalize_product, prod
...@@ -214,6 +213,11 @@ def diff_terms(expr): ...@@ -214,6 +213,11 @@ def diff_terms(expr):
This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression, This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
since this function only returns the outer derivatives since this function only returns the outer derivatives
Example:
>>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
""" """
result = set() result = set()
......
...@@ -11,6 +11,7 @@ from .derivation import FiniteDifferenceStencilDerivation ...@@ -11,6 +11,7 @@ from .derivation import FiniteDifferenceStencilDerivation
def fd_stencils_standard(indices, dx, fa): def fd_stencils_standard(indices, dx, fa):
order = len(indices) order = len(indices)
assert all(i >= 0 for i in indices), "Can only discretize objects with (integer) subscripts"
if order == 1: if order == 1:
idx = indices[0] idx = indices[0]
return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx) return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx)
...@@ -122,7 +123,6 @@ def discretize_spatial(expr, dx, stencil=fd_stencils_standard): ...@@ -122,7 +123,6 @@ def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard): def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
def staggered_visitor(e, coordinate, sign): def staggered_visitor(e, coordinate, sign):
if isinstance(e, Diff): if isinstance(e, Diff):
arg, *indices = diff_args(e) arg, *indices = diff_args(e)
......
import sympy as sp import sympy as sp
import pystencils as ps
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
from pystencils.fd import * from pystencils.fd import *
from sympy.abc import a, b, t, x, y, z
def test_derivative_basic(): def test_derivative_basic():
x, y, z, t = sp.symbols("x y z t")
d = diff d = diff
op1, op2, op3 = DiffOperator(), DiffOperator(target=x), DiffOperator(target=x, superscript=1) op1, op2, op3 = DiffOperator(), DiffOperator(target=x), DiffOperator(target=x, superscript=1)
...@@ -18,4 +19,31 @@ def test_derivative_basic(): ...@@ -18,4 +19,31 @@ def test_derivative_basic():
assert diff_term == dx**2 + 2 * dx * dy + dy**2 + 1 assert diff_term == dx**2 + 2 * dx * dy + dy**2 + 1
assert DiffOperator.apply(diff_term, t) == d(t, x, x) + 2 * d(t, x, y) + d(t, y, y) + t assert DiffOperator.apply(diff_term, t) == d(t, x, x) + 2 * d(t, x, y) + d(t, y, y) + t
assert ps.fd.Diff(0) == 0
expr = ps.fd.diff(ps.fd.diff(x, 0, 0), 1, 1)
assert expr.get_arg_recursive() == x
assert expr.change_arg_recursive(y).get_arg_recursive() == y
def test_derivative_expand_collect():
original = Diff(x*y*z)
result = combine_diff_products(combine_diff_products(expand_diff_products(original))).expand()
assert original == result
original = -3 * y * z * Diff(x) + 2 * x * z * Diff(y)
result = expand_diff_products(combine_diff_products(original)).expand()
assert original == result
original = a + b * Diff(x ** 2 * y * z)
expanded = expand_diff_products(original)
collect_res = combine_diff_products(combine_diff_products(combine_diff_products(expanded)))
assert collect_res == original
def test_diff_expand_using_linearity():
eps = sp.symbols("epsilon")
funcs = [a, b]
test = Diff(eps * Diff(a+b))
result = expand_diff_linear(test, functions=funcs)
assert result == eps * Diff(Diff(a)) + eps * Diff(Diff(b))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment