Commit 61d1bae6 authored by Martin Bauer's avatar Martin Bauer
Browse files

Tests and documentation for derivative module

parent beec6d3e
......@@ -130,4 +130,4 @@ pages:
tags:
- docker
only:
- master@software/pystencils
- master@pycodegen/pystencils
This diff is collapsed.
......@@ -15,6 +15,6 @@ It is a good idea to download them and run them directly to be able to play arou
/notebooks/05_tutorial_phasefield_spinodal_decomposition.ipynb
/notebooks/06_tutorial_phasefield_dentritic_growth.ipynb
/notebooks/demo_assignment_collection.ipynb
/notebooks/demo_derivatives.ipynb
/notebooks/demo_benchmark.ipynb
/notebooks/demo_wave_equation.ipynb
import sympy as sp
from collections import namedtuple, defaultdict
from pystencils import Field
from pystencils.sympyextensions import normalize_product, prod
......@@ -214,6 +213,11 @@ def diff_terms(expr):
This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
since this function only returns the outer derivatives
Example:
>>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
"""
result = set()
......
......@@ -11,6 +11,7 @@ from .derivation import FiniteDifferenceStencilDerivation
def fd_stencils_standard(indices, dx, fa):
order = len(indices)
assert all(i >= 0 for i in indices), "Can only discretize objects with (integer) subscripts"
if order == 1:
idx = indices[0]
return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx)
......@@ -122,7 +123,6 @@ def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
def staggered_visitor(e, coordinate, sign):
if isinstance(e, Diff):
arg, *indices = diff_args(e)
......
import sympy as sp
import pystencils as ps
from sympy.printing.latex import LatexPrinter
from pystencils.fd import *
from sympy.abc import a, b, t, x, y, z
def test_derivative_basic():
x, y, z, t = sp.symbols("x y z t")
d = diff
op1, op2, op3 = DiffOperator(), DiffOperator(target=x), DiffOperator(target=x, superscript=1)
......@@ -18,4 +19,31 @@ def test_derivative_basic():
assert diff_term == dx**2 + 2 * dx * dy + dy**2 + 1
assert DiffOperator.apply(diff_term, t) == d(t, x, x) + 2 * d(t, x, y) + d(t, y, y) + t
assert ps.fd.Diff(0) == 0
expr = ps.fd.diff(ps.fd.diff(x, 0, 0), 1, 1)
assert expr.get_arg_recursive() == x
assert expr.change_arg_recursive(y).get_arg_recursive() == y
def test_derivative_expand_collect():
original = Diff(x*y*z)
result = combine_diff_products(combine_diff_products(expand_diff_products(original))).expand()
assert original == result
original = -3 * y * z * Diff(x) + 2 * x * z * Diff(y)
result = expand_diff_products(combine_diff_products(original)).expand()
assert original == result
original = a + b * Diff(x ** 2 * y * z)
expanded = expand_diff_products(original)
collect_res = combine_diff_products(combine_diff_products(combine_diff_products(expanded)))
assert collect_res == original
def test_diff_expand_using_linearity():
eps = sp.symbols("epsilon")
funcs = [a, b]
test = Diff(eps * Diff(a+b))
result = expand_diff_linear(test, functions=funcs)
assert result == eps * Diff(Diff(a)) + eps * Diff(Diff(b))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment