diff --git a/fd/derivative.py b/fd/derivative.py index 902bda803a392c0637400e09f828d927215b85e5..cd74b0fe5cb6e370e145160ddf7738b150085a79 100644 --- a/fd/derivative.py +++ b/fd/derivative.py @@ -304,6 +304,8 @@ def expand_diff_full(expr, functions=None, constants=None): processed_diff = normalize_diff_order(Diff(dependent_term, **diff_args)) result += independent_terms * prod(other_dependent_terms) * processed_diff return result + elif isinstance(e, sp.Piecewise): + return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args)) else: new_args = [visit(arg) for arg in e.args] return e.func(*new_args) if new_args else e @@ -341,6 +343,8 @@ def expand_diff_linear(expr, functions=None, constants=None): return 0 else: return diff.split_linear(functions) + elif isinstance(expr, sp.Piecewise): + return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args)) else: new_args = [expand_diff_linear(e, functions) for e in expr.args] result = sp.expand(expr.func(*new_args) if new_args else expr)