Skip to content
Snippets Groups Projects
Commit cd4c034b authored by Martin Bauer's avatar Martin Bauer
Browse files

Increased test coverage

parent 8c124ade
Branches
Tags
No related merge requests found
......@@ -198,11 +198,6 @@ def get_cache_config():
return _config['cache']
def hash_to_function_name(h):
res = "func_%s" % (h,)
return res.replace('-', 'm')
def add_or_change_compiler_flags(flags):
if not isinstance(flags, list) and not isinstance(flags, tuple):
flags = [flags]
......
from .derivative import Diff, DiffOperator, \
diff_terms, collect_diffs, replace_diff, zero_diffs, evaluate_diffs, normalize_diff_order, \
diff_terms, collect_diffs, zero_diffs, evaluate_diffs, normalize_diff_order, \
expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \
functional_derivative, diff
from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder
from .spatial import discretize_spatial
__all__ = ['Diff', 'diff', 'DiffOperator', 'diff_terms', 'collect_diffs', 'replace_diff',
__all__ = ['Diff', 'diff', 'DiffOperator', 'diff_terms', 'collect_diffs',
'zero_diffs', 'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
'expand_diff_products', 'combine_diff_products', 'functional_derivative',
'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial']
......@@ -233,21 +233,15 @@ def collect_diffs(expr):
return expr.collect(diff_terms(expr))
def replace_diff(expr, replacement_dict):
"""replacement_dict: maps variable (target) to a new Differential operator"""
def visit(e):
if isinstance(e, Diff):
if e.target in replacement_dict:
return DiffOperator.apply(replacement_dict[e.target], visit(e.arg))
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def zero_diffs(expr, label):
"""Replaces all differentials with the given target by 0"""
"""Replaces all differentials with the given target by 0
Example:
>>> x, y, f = sp.symbols("x y f")
>>> expression = Diff(f, x) + Diff(f, y) + Diff(Diff(f, y), x) + 7
>>> zero_diffs(expression, x)
Diff(f, y, -1) + 7
"""
def visit(e):
if isinstance(e, Diff):
......@@ -493,6 +487,11 @@ def replace_generic_laplacian(expr, dim=None):
This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
dimension can be determined.
>>> l = Diff(Diff(sp.symbols('x')))
>>> replace_generic_laplacian(l, 3)
Diff(Diff(x, 0, -1), 0, -1) + Diff(Diff(x, 1, -1), 1, -1) + Diff(Diff(x, 2, -1), 2, -1)
"""
if isinstance(expr, Diff):
arg, *indices = diff_args(expr)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment