diff --git a/derivative.py b/derivative.py
index 2dbcae4519ad217cdac8ec5f275ba1e2603c92fb..3135761f79cd39b8e5c7e84b228ca15a7f351bb7 100644
--- a/derivative.py
+++ b/derivative.py
@@ -195,10 +195,10 @@ def collect_derivatives(expr):
     return expr.collect(derivative_terms(expr))
 
 
-def create_nested_diff(*args, arg=None):
+def create_nested_diff(arg, *args):
     """Shortcut to create nested derivatives"""
     assert arg is not None
-    args = sorted(args, reverse=True)
+    args = sorted(args, reverse=True, key=lambda e: e.name if isinstance(e, sp.Symbol) else e)
     res = arg
     for i in args:
         res = Diff(res, i)
@@ -455,7 +455,7 @@ def functional_derivative(functional, v):
         \frac{\delta F}{\delta v} =
                 \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v}
 
-    - assumes that gradients are represented by Diff() node (from Chapman Enskog module)
+    - assumes that gradients are represented by Diff() node
     - Diff(Diff(r)) represents the divergence of r
     - the constants parameter is a list with symbols not affected by the derivative. This is used for simplification
       of the derivative terms.