Skip to content
Snippets Groups Projects
Commit cd4c034b authored by Martin Bauer's avatar Martin Bauer
Browse files

Increased test coverage

parent 8c124ade
Branches
Tags
No related merge requests found
...@@ -198,11 +198,6 @@ def get_cache_config(): ...@@ -198,11 +198,6 @@ def get_cache_config():
return _config['cache'] return _config['cache']
def hash_to_function_name(h):
res = "func_%s" % (h,)
return res.replace('-', 'm')
def add_or_change_compiler_flags(flags): def add_or_change_compiler_flags(flags):
if not isinstance(flags, list) and not isinstance(flags, tuple): if not isinstance(flags, list) and not isinstance(flags, tuple):
flags = [flags] flags = [flags]
......
from .derivative import Diff, DiffOperator, \ from .derivative import Diff, DiffOperator, \
diff_terms, collect_diffs, replace_diff, zero_diffs, evaluate_diffs, normalize_diff_order, \ diff_terms, collect_diffs, zero_diffs, evaluate_diffs, normalize_diff_order, \
expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \ expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \
functional_derivative, diff functional_derivative, diff
from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder
from .spatial import discretize_spatial from .spatial import discretize_spatial
__all__ = ['Diff', 'diff', 'DiffOperator', 'diff_terms', 'collect_diffs', 'replace_diff', __all__ = ['Diff', 'diff', 'DiffOperator', 'diff_terms', 'collect_diffs',
'zero_diffs', 'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear', 'zero_diffs', 'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
'expand_diff_products', 'combine_diff_products', 'functional_derivative', 'expand_diff_products', 'combine_diff_products', 'functional_derivative',
'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial'] 'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial']
...@@ -233,21 +233,15 @@ def collect_diffs(expr): ...@@ -233,21 +233,15 @@ def collect_diffs(expr):
return expr.collect(diff_terms(expr)) return expr.collect(diff_terms(expr))
def replace_diff(expr, replacement_dict):
"""replacement_dict: maps variable (target) to a new Differential operator"""
def visit(e):
if isinstance(e, Diff):
if e.target in replacement_dict:
return DiffOperator.apply(replacement_dict[e.target], visit(e.arg))
new_args = [visit(arg) for arg in e.args]
return e.func(*new_args) if new_args else e
return visit(expr)
def zero_diffs(expr, label): def zero_diffs(expr, label):
"""Replaces all differentials with the given target by 0""" """Replaces all differentials with the given target by 0
Example:
>>> x, y, f = sp.symbols("x y f")
>>> expression = Diff(f, x) + Diff(f, y) + Diff(Diff(f, y), x) + 7
>>> zero_diffs(expression, x)
Diff(f, y, -1) + 7
"""
def visit(e): def visit(e):
if isinstance(e, Diff): if isinstance(e, Diff):
...@@ -493,6 +487,11 @@ def replace_generic_laplacian(expr, dim=None): ...@@ -493,6 +487,11 @@ def replace_generic_laplacian(expr, dim=None):
This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ... This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
For this to work, the arguments of the derivative have to be field or field accesses such that the spatial For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
dimension can be determined. dimension can be determined.
>>> l = Diff(Diff(sp.symbols('x')))
>>> replace_generic_laplacian(l, 3)
Diff(Diff(x, 0, -1), 0, -1) + Diff(Diff(x, 1, -1), 1, -1) + Diff(Diff(x, 2, -1), 2, -1)
""" """
if isinstance(expr, Diff): if isinstance(expr, Diff):
arg, *indices = diff_args(expr) arg, *indices = diff_args(expr)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment