derivative.py 20.8 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
from collections import defaultdict, namedtuple

Martin Bauer's avatar
Martin Bauer committed
3
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
4
5

from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.sympyextensions import normalize_product, prod
Martin Bauer's avatar
Martin Bauer committed
7
8


9
def _default_diff_sort_key(d):
10
    return str(d.superscript), str(d.target)
Martin Bauer's avatar
Martin Bauer committed
11
12
13


class Diff(sp.Expr):
    """Sympy Node representing a derivative.

    The difference to sympy's built in differential is:
        - shortened latex representation
        - all simplifications have to be done manually
        - optional marker displayed as superscript

    Args:
        argument: expression the derivative acts on. A ``Field`` is replaced by its ``center``;
                  an argument equal to zero yields the sympy number 0 instead of a Diff node.
        target: subscript, usually the variable the derivative is taken with respect to.
                Not displayed in LaTeX if it is ``-1``.
        superscript: optional marker, displayed in LaTeX only if it is ``>= 0``
    """
    is_number = False
    is_Rational = False
    _diff_wrt = True

    def __new__(cls, argument, target=-1, superscript=-1):
        if argument == 0:
            return sp.Rational(0, 1)
        if isinstance(argument, Field):
            argument = argument.center
        # sympify first: plain Python numbers or strings have no 'expand' method
        argument = sp.sympify(argument).expand()
        return sp.Expr.__new__(cls, argument, sp.sympify(target), sp.sympify(superscript))

    @property
    def is_commutative(self):
        """True only if every symbol contained in this node is known to be commutative"""
        return all(s.is_commutative for s in self.atoms(sp.Symbol))

    def get_arg_recursive(self):
        """Returns the argument the derivative acts on, for nested derivatives the inner argument is returned"""
        if not isinstance(self.arg, Diff):
            return self.arg
        else:
            return self.arg.get_arg_recursive()

    def change_arg_recursive(self, new_arg):
        """Returns a Diff node with the given 'new_arg' instead of the current argument. For nested derivatives
        a new nested derivative is returned where the inner Diff has the 'new_arg'"""
        if not isinstance(self.arg, Diff):
            return Diff(new_arg, self.target, self.superscript)
        else:
            return Diff(self.arg.change_arg_recursive(new_arg), self.target, self.superscript)

    def split_linear(self, functions):
        """
        Applies linearity property of Diff: i.e.  'Diff(c*a+b)' is transformed to 'c * Diff(a) + Diff(b)'
        The parameter functions is a list of all symbols that are considered functions, not constants.
        For the example above: functions=[a, b]

        Only products are split here; returns 0 if the argument contains no function at all.
        """
        constant, variable = 1, 1

        if self.arg.func != sp.Mul:
            constant, variable = 1, self.arg
        else:
            # factors that are functions or derivatives stay inside, everything else is pulled out
            for factor in normalize_product(self.arg):
                if factor in functions or isinstance(factor, Diff):
                    variable *= factor
                else:
                    constant *= factor

        if isinstance(variable, sp.Symbol) and variable not in functions:
            return 0

        # 'variable' is still the Python int 1 if the product contained no function
        if isinstance(variable, int) or variable.is_number:
            return 0
        else:
            return constant * Diff(variable, target=self.target, superscript=self.superscript)

    @property
    def arg(self):
        """Expression the derivative acts on"""
        return self.args[0]

    @property
    def target(self):
        """Subscript, usually the variable the Diff is w.r.t. """
        return self.args[1]

    @property
    def superscript(self):
        """Superscript, for example used as the Chapman-Enskog order index"""
        return self.args[2]

    def _latex(self, printer, *_):
        """LaTeX representation: partial symbol with optional superscript / subscript followed by the argument"""
        result = r"{\partial"
        if self.superscript >= 0:
            result += "^{(%s)}" % (self.superscript,)
        if self.target != -1:
            result += "_{%s}" % (printer.doprint(self.target),)

        contents = printer.doprint(self.arg)
        # atomic arguments and nested derivatives are printed without parentheses
        if isinstance(self.arg, int) or isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff:
            result += " " + contents
        else:
            result += " (" + contents + ") "

        result += "}"
        return result

    def __str__(self):
        return "D(%s)" % self.arg

    def interpolated_access(self, offset, **kwargs):
        """Represents an interpolated access on a spatially differentiated field

        Args:
            offset (Tuple[sympy.Expr]): Absolute position to determine the value of the spatial derivative
            **kwargs: forwarded to ``Field.interpolated_access``
        """
        # imported here and not at module level, presumably to avoid a circular import
        from pystencils.interpolation_astnodes import DiffInterpolatorAccess
        assert isinstance(self.arg.field, Field), "Must be field to enable interpolated accesses"
        return DiffInterpolatorAccess(self.arg.field.interpolated_access(offset, **kwargs).symbol, self.target, *offset)
123

Martin Bauer's avatar
Martin Bauer committed
124

125
class DiffOperator(sp.Expr):
    """Un-applied differential, i.e. differential operator

    Args:
        target: the differential is w.r.t to this variable.
                 This target is mainly for display purposes (its the subscript) and to distinguish DiffOperators
                 If the target is '-1' no subscript is displayed
        superscript: optional marker displayed as superscript
                     is not displayed if set to '-1'

    The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the
    DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's
    """
    is_commutative = True
    is_number = False
    is_Rational = False

    def __new__(cls, target=-1, superscript=-1):
        return sp.Expr.__new__(cls, sp.sympify(target), sp.sympify(superscript))

    @property
    def target(self):
        """Subscript, the variable the operator differentiates with respect to"""
        return self.args[0]

    @property
    def superscript(self):
        """Optional marker displayed as superscript"""
        return self.args[1]

    def _latex(self, *_):
        """LaTeX representation: partial symbol with optional superscript and subscript"""
        result = r"{\partial"
        if self.superscript >= 0:
            result += "^{(%s)}" % (self.superscript,)
        if self.target != -1:
            result += "_{%s}" % (self.target,)
        result += "}"
        return result

    @staticmethod
    def apply(expr, argument, apply_to_constants=True):
        """
        Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node.
        Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
        i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), x)

        Args:
            expr: expression containing DiffOperators, it is expanded before processing
            argument: expression the operators are applied to
            apply_to_constants: if True, terms that contain no DiffOperator are multiplied by 'argument',
                                otherwise they are returned unchanged
        """

        def handle_mul(mul):
            args = normalize_product(mul)
            diffs = [a for a in args if isinstance(a, DiffOperator)]
            if len(diffs) == 0:
                return mul * argument if apply_to_constants else mul
            rest = [a for a in args if not isinstance(a, DiffOperator)]
            diffs.sort(key=_default_diff_sort_key)
            # nest the derivatives: the last operator in sorted order becomes the innermost Diff
            result = argument
            for d in reversed(diffs):
                result = Diff(result, target=d.target, superscript=d.superscript)
            return prod(rest) * result

        expr = expr.expand()
        if expr.func == sp.Mul or expr.func == sp.Pow:
            return handle_mul(expr)
        elif expr.func == sp.Add:
            return expr.func(*[handle_mul(a) for a in expr.args])
        else:
            return expr * argument if apply_to_constants else expr
189

Martin Bauer's avatar
Martin Bauer committed
190

Martin Bauer's avatar
Martin Bauer committed
191
192
# ----------------------------------------------------------------------------------------------------------------------

193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def diff(expr, *args):
    """Builds nested derivatives of 'expr', one Diff node per entry of 'args'.

    The last entry of 'args' becomes the innermost derivative; without 'args' the expression
    is returned unchanged.

    >>> f = sp.Symbol("f")
    >>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0)
    True
    """
    nested = expr
    for target in args[::-1]:
        nested = Diff(nested, target)
    return nested


def diff_args(expr):
    """Splits a possibly nested derivative into its innermost argument and the targets - inverse of diff function

    The result is a tuple whose first entry is the innermost non-Diff argument, followed by the
    targets ordered from the outermost to the innermost derivative.

    >>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1)
    >>> e = diff(*args)
    >>> assert diff_args(e) == args
    """
    targets = []
    node = expr
    # peel off one derivative per iteration, remembering its target
    while isinstance(node, Diff):
        targets.append(node.args[1])
        node = node.args[0]
    return (node, *targets)


223
224
225
def diff_terms(expr):
    """Returns set of all derivatives in an expression.

    This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
    since this function only returns the outer derivatives

    Example:
        >>> x, y = sp.symbols("x, y")
        >>> diff_terms( diff(x, 0, 0)  )
        {Diff(Diff(x, 0, -1), 0, -1)}
    """
    result = set()

    def visit(e):
        if isinstance(e, Diff):
            # do not descend further, so derivatives nested inside 'e' are not reported separately
            result.add(e)
        else:
            for a in e.args:
                visit(a)

    visit(expr)
    return result


247
def collect_diffs(expr):
    """Rewrites expression into a sum of distinct derivatives with pre-factors"""
    return expr.collect(diff_terms(expr))
Martin Bauer's avatar
Martin Bauer committed
250
251


252
def zero_diffs(expr, label):
    """Replaces all differentials with the given target by 0

    Example:
        >>> x, y, f = sp.symbols("x y f")
        >>> expression = Diff(f, x) + Diff(f, y) + Diff(Diff(f, y), x) + 7
        >>> zero_diffs(expression, x)
        Diff(f, y, -1) + 7
    """

    def visit(e):
        if isinstance(e, Diff) and e.target == label:
            return 0
        # rebuild the node from its processed children, atoms are returned as they are
        new_args = [visit(arg) for arg in e.args]
        return e.func(*new_args) if new_args else e

    return visit(expr)


def evaluate_diffs(expr, var=None):
    """Replaces pystencils diff objects by sympy diff objects and evaluates them.

    Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise
    the specified var

    Args:
        expr: expression containing Diff nodes
        var: variable to differentiate with respect to; if None every Diff node uses its own target
    """
    if isinstance(expr, Diff):
        # Do not overwrite 'var' here: with var=None a nested Diff has to use its own target,
        # not the target of the enclosing Diff.
        free_variable = expr.target if var is None else var
        return sp.diff(evaluate_diffs(expr.arg, var), free_variable)
    else:
        new_args = [evaluate_diffs(arg, var) for arg in expr.args]
        return expr.func(*new_args) if new_args else expr


def normalize_diff_order(expression, functions=None, constants=None, sort_key=_default_diff_sort_key):
    """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
    by the sorting key 'sort_key' such that the derivative terms can be further simplified

    Args:
        expression: expression containing Diff nodes; it is expanded, also with expand_diff_linear, first
        functions: passed on to expand_diff_linear
        constants: passed on to expand_diff_linear
        sort_key: function mapping a Diff node to a sortable key
    """

    def visit(expr):
        if isinstance(expr, Diff):
            # collect the chain of directly nested Diff nodes, outermost first
            nodes = [expr]
            while isinstance(nodes[-1].arg, Diff):
                nodes.append(nodes[-1].arg)

            processed_arg = visit(nodes[-1].arg)
            nodes.sort(key=sort_key)

            # re-nest in sorted order: the last node in sorted order becomes the innermost derivative
            result = processed_arg
            for d in reversed(nodes):
                result = Diff(result, target=d.target, superscript=d.superscript)
            return result
        else:
            new_args = [visit(e) for e in expr.args]
            return expr.func(*new_args) if new_args else expr

    expression = expand_diff_linear(expression.expand(), functions, constants).expand()
    return visit(expression)
Martin Bauer's avatar
Martin Bauer committed
310
311


312
def expand_diff_full(expr, functions=None, constants=None):
    """Expands all derivative nodes of sums and products using linearity and the product rule

    Args:
        expr: expression containing derivatives, a sp.Matrix is processed element-wise
        functions: symbols that are considered functions and are kept inside the derivative.
                   if None, all symbols of 'expr' except the given 'constants' are viewed as functions
        constants: symbols which are considered constants, only used if 'functions' is None
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
        if constants is not None:
            functions.difference_update(constants)

    def visit(e):
        if not isinstance(e, sp.Tuple):
            e = e.expand()

        if e.func == Diff:
            result = 0
            diff_kwargs = {'target': e.target, 'superscript': e.superscript}
            diff_inner = visit(e.args[0])
            if diff_inner.func not in (sp.Add, sp.Mul):
                return e
            for term in diff_inner.args if diff_inner.func == sp.Add else [diff_inner]:
                # split the factors of each summand into a constant pre-factor and the dependent factors
                independent_terms = 1
                dependent_terms = []
                for factor in normalize_product(term):
                    if factor in functions or isinstance(factor, Diff):
                        dependent_terms.append(factor)
                    else:
                        independent_terms *= factor
                # product rule: differentiate one dependent factor at a time
                for i, dependent_term in enumerate(dependent_terms):
                    other_dependent_terms = dependent_terms[:i] + dependent_terms[i + 1:]
                    processed_diff = normalize_diff_order(Diff(dependent_term, **diff_kwargs))
                    result += independent_terms * prod(other_dependent_terms) * processed_diff
            return result
        elif isinstance(e, sp.Piecewise):
            return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args))
        elif isinstance(e, sp.Tuple):
            # test the visited node 'e' here: testing the outer 'expr' would rebuild every sub-node
            # of a top-level Tuple as a Tuple
            new_args = [visit(arg) for arg in e.args]
            return sp.Tuple(*new_args)
        else:
            new_args = [visit(arg) for arg in e.args]
            return e.func(*new_args) if new_args else e

    if isinstance(expr, sp.Matrix):
        return expr.applyfunc(visit)
    else:
        return visit(expr)


358
359
def expand_diff_linear(expr, functions=None, constants=None):
    """Expands all derivative nodes by applying Diff.split_linear

    Args:
        expr: expression containing derivatives
        functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
                   if None, all symbols are viewed as functions
        constants: sequence of symbols which are considered constants and can be pulled before the derivative
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
        if constants is not None:
            functions.difference_update(constants)

    # 'functions' is fully determined from here on, so the recursive calls do not need 'constants'
    if isinstance(expr, Diff):
        arg = expand_diff_linear(expr.arg, functions)
        if hasattr(arg, 'func') and arg.func == sp.Add:
            # derivative of a sum: split every summand separately
            result = 0
            for a in arg.args:
                result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
            return result
        else:
            diff_node = Diff(arg, target=expr.target, superscript=expr.superscript)
            # Diff(...) returns a plain zero if the argument vanished
            if diff_node == 0:
                return 0
            else:
                return diff_node.split_linear(functions)
    elif isinstance(expr, sp.Piecewise):
        return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args))
    elif isinstance(expr, sp.Tuple):
        new_args = [expand_diff_linear(e, functions) for e in expr.args]
        return sp.Tuple(*new_args)
    else:
        new_args = [expand_diff_linear(e, functions) for e in expr.args]
        result = sp.expand(expr.func(*new_args) if new_args else expr)
        return result
Martin Bauer's avatar
Martin Bauer committed
394
395


396
def expand_diff_products(expr):
    """Fully expands all derivatives by applying product rule"""
    if isinstance(expr, Diff):
        arg = expand_diff_products(expr.args[0])
        if arg.func == sp.Add:
            # NOTE(review): the summands are wrapped in Diff nodes without applying the product rule to
            # them again, so a product inside a sum stays unexpanded - confirm this is intended
            new_args = [Diff(e, target=expr.target, superscript=expr.superscript)
                        for e in arg.args]
            return sp.Add(*new_args)
        if arg.func not in (sp.Mul, sp.Pow):
            return Diff(arg, target=expr.target, superscript=expr.superscript)
        else:
            # product rule: sum over all factors of (derivative of that factor) * (all other factors)
            prod_list = normalize_product(arg)
            result = 0
            for i, factor in enumerate(prod_list):
                pre_factor = prod(other for j, other in enumerate(prod_list) if i != j)
                result += pre_factor * Diff(factor, target=expr.target, superscript=expr.superscript)
            return result
    else:
        new_args = [expand_diff_products(e) for e in expr.args]
        return expr.func(*new_args) if new_args else expr
Martin Bauer's avatar
Martin Bauer committed
416
417


418
def combine_diff_products(expr):
    """Inverse product rule

    Sums are decomposed into derivative terms with pre-factors, which are then pairwise combined
    into derivatives of products where possible. Other nodes are processed recursively.
    """

    def expr_to_diff_decomposition(expression):
        """Decomposes a sp.Add node containing Diffs into:
        diff_dict: maps (target, superscript) -> [ (pre_factor, argument), ... ]
        i.e.  a partial(b) ( a is pre-factor, b is argument)
            in case of partial(a) partial(b) two entries are created  (0.5 partial(a), b), (0.5 partial(b), a)
        rest: sum of all terms that contain no Diff
        """
        DiffInfo = namedtuple("DiffInfo", ["target", "superscript"])

        class DiffSplit:
            def __init__(self, fac, argument):
                self.pre_factor = fac
                self.argument = argument

            def __repr__(self):
                return str((self.pre_factor, self.argument))

        assert isinstance(expression, sp.Add)
        diff_dict = defaultdict(list)
        rest = 0
        for term in expression.args:
            if isinstance(term, Diff):
                diff_dict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg))
            else:
                mul_args = normalize_product(term)
                diffs = [d for d in mul_args if isinstance(d, Diff)]
                factor = prod(d for d in mul_args if not isinstance(d, Diff))
                if len(diffs) == 0:
                    rest += factor
                else:
                    # one entry per Diff factor, each weighted by 1/len(diffs)
                    for i, current in enumerate(diffs):
                        all_but_current = [d for j, d in enumerate(diffs) if i != j]
                        pre_factor = factor * prod(all_but_current) * sp.Rational(1, len(diffs))
                        diff_dict[DiffInfo(current.target, current.superscript)].append(
                            DiffSplit(pre_factor, current.arg))

        return diff_dict, rest

    def match_diff_splits(own, other):
        """Returns None if the two splits are not combined, otherwise the difference of the reduced pre-factors"""
        own_fac = own.pre_factor / other.argument
        other_fac = other.pre_factor / own.argument
        count = sp.count_ops
        # no match if dividing by the other argument makes one of the pre-factors more complex
        if count(own_fac) > count(own.pre_factor) or count(other_fac) > count(other.pre_factor):
            return None

        new_other_factor = own_fac - other_fac
        return new_other_factor

    def process_diff_list(diff_list, label, superscript):
        """Combines the first entry of 'diff_list' with its best match, then recurses on the remaining entries"""
        if len(diff_list) == 0:
            return 0
        elif len(diff_list) == 1:
            return diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)

        result = 0
        matches = []
        for i in range(1, len(diff_list)):
            match_result = match_diff_splits(diff_list[i], diff_list[0])
            if match_result is not None:
                matches.append((i, match_result))

        if len(matches) == 0:
            result += diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)
        else:
            # take the match with the smallest operation count, the first one wins on ties
            other_idx, match_result = min(matches, key=lambda e: sp.count_ops(e[1]))
            new_argument = diff_list[0].argument * diff_list[other_idx].argument
            result += (diff_list[0].pre_factor / diff_list[other_idx].argument) * Diff(new_argument, label, superscript)
            if match_result == 0:
                # the partner was consumed completely
                del diff_list[other_idx]
            else:
                # the partner stays in the list with the remaining pre-factor
                diff_list[other_idx].pre_factor = match_result * diff_list[0].argument
        result += process_diff_list(diff_list[1:], label, superscript)
        return result

    def combine(expression):
        expression = expression.expand()
        if isinstance(expression, sp.Add):
            diff_dict, rest = expr_to_diff_decomposition(expression)
            for (label, superscript), diff_list in diff_dict.items():
                rest += process_diff_list(diff_list, label, superscript)
            return rest
        else:
            new_args = [combine_diff_products(e) for e in expression.args]
            return expression.func(*new_args) if new_args else expression

    return combine(expr)
Martin Bauer's avatar
Martin Bauer committed
505

Martin Bauer's avatar
Martin Bauer committed
506

507
508
509
510
511
512
def replace_generic_laplacian(expr, dim=None):
    """Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions.

    This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
    For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
    dimension can be determined.

    Args:
        expr: expression possibly containing generic laplacians
        dim: number of spatial dimensions, used if the argument of the derivative is not a field access

    >>> l = Diff(Diff(sp.symbols('x')))
    >>> replace_generic_laplacian(l, 3)
    Diff(Diff(x, 0, -1), 0, -1) + Diff(Diff(x, 1, -1), 1, -1) + Diff(Diff(x, 2, -1), 2, -1)

    """
    if isinstance(expr, Diff):
        arg, *indices = diff_args(expr)
        if isinstance(arg, Field.Access):
            dim = arg.field.spatial_dimensions
        if len(indices) == 2 and all(i == -1 for i in indices):
            # the dimension is only needed if there is actually something to replace
            assert dim is not None, "Spatial dimension required: pass 'dim' or use field accesses as arguments"
            return sum(diff(arg, i, i) for i in range(dim))
        else:
            return expr
    else:
        new_args = [replace_generic_laplacian(a, dim) for a in expr.args]
        return expr.func(*new_args) if new_args else expr


Martin Bauer's avatar
Martin Bauer committed
533
def functional_derivative(functional, v):
    r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation

    .. math ::

        \frac{\delta F}{\delta v} =
                \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v}

    - assumes that gradients are represented by Diff() node
    - Diff(Diff(r)) represents the divergence of r

    Args:
        functional: expression of the functional, gradients written as Diff nodes
        v: quantity the functional derivative is taken with respect to
    """
    diffs = functional.atoms(Diff)

    # Hide all derivative nodes behind dummies, so that the plain derivative w.r.t. v does not act on them
    bulk_substitutions = {d: sp.Dummy() for d in diffs}
    bulk_substitutions_inverse = {dummy: d for d, dummy in bulk_substitutions.items()}
    non_diff_part = functional.subs(bulk_substitutions)
    partial_f_partial_v = sp.diff(non_diff_part, v).subs(bulk_substitutions_inverse)

    gradient_part = 0
    for diff_obj in diffs:
        # only derivatives acting directly on v contribute to the gradient part
        if diff_obj.args[0] != v:
            continue
        # differentiate w.r.t. the derivative node itself by temporarily replacing it with a dummy
        dummy = sp.Dummy()
        partial_f_partial_grad_v = functional.subs(diff_obj, dummy).diff(dummy).subs(dummy, diff_obj)
        gradient_part += Diff(partial_f_partial_grad_v, target=diff_obj.target, superscript=diff_obj.superscript)

    result = partial_f_partial_v - gradient_part
    return result