derivative.py 20.9 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
from collections import defaultdict, namedtuple

Martin Bauer's avatar
Martin Bauer committed
3
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
4
5

from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.sympyextensions import normalize_product, prod
Martin Bauer's avatar
Martin Bauer committed
7
8


9
def _default_diff_sort_key(d):
10
    return str(d.superscript), str(d.target)
Martin Bauer's avatar
Martin Bauer committed
11
12
13


class Diff(sp.Expr):
    """Sympy Node representing a derivative.

    The difference to sympy's built in differential is:
        - shortened latex representation
        - all simplifications have to be done manually
        - optional marker displayed as superscript

    The node stores the three args ``(argument, target, superscript)``. A target of ``-1`` is not displayed;
    a superscript is not displayed if it is a negative number (default ``-1``).
    """
    is_number = False
    is_Rational = False
    _diff_wrt = True

    def __new__(cls, argument, target=-1, superscript=-1):
        # the derivative of zero is zero: return a plain sympy zero instead of a Diff node
        if argument == 0:
            return sp.Rational(0, 1)
        # a whole field stands for its value at the center
        if isinstance(argument, Field):
            argument = argument.center
        # sympify so that plain Python numbers (which have no .expand()) are accepted as well,
        # consistent with the handling of target and superscript
        argument = sp.sympify(argument)
        return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript))

    @property
    def is_commutative(self):
        """A Diff node is commutative unless it contains a non-commutative symbol."""
        return all(s.is_commutative for s in self.atoms(sp.Symbol))

    def get_arg_recursive(self):
        """Returns the argument the derivative acts on, for nested derivatives the inner argument is returned"""
        if not isinstance(self.arg, Diff):
            return self.arg
        else:
            return self.arg.get_arg_recursive()

    def change_arg_recursive(self, new_arg):
        """Returns a Diff node with the given 'new_arg' instead of the current argument. For nested derivatives
        a new nested derivative is returned where the inner Diff has the 'new_arg'"""
        if not isinstance(self.arg, Diff):
            return Diff(new_arg, self.target, self.superscript)
        else:
            return Diff(self.arg.change_arg_recursive(new_arg), self.target, self.superscript)

    def split_linear(self, functions):
        """
        Applies linearity property of Diff: i.e.  'Diff(c*a+b)' is transformed to 'c * Diff(a) + Diff(b)'
        The parameter functions is a list of all symbols that are considered functions, not constants.
        For the example above: functions=[a, b]

        Note: only products are split here; sums have to be split by the caller (see expand_diff_linear).
        Returns 0 if the argument does not depend on any of the given functions.
        """
        constant, variable = 1, 1
        if self.arg.func != sp.Mul:
            variable = self.arg
        else:
            for factor in normalize_product(self.arg):
                if factor in functions or isinstance(factor, Diff):
                    variable *= factor
                else:
                    constant *= factor

        # a symbol that is not a function is a constant -> derivative vanishes
        if isinstance(variable, sp.Symbol) and variable not in functions:
            return 0

        if isinstance(variable, int) or variable.is_number:
            return 0
        else:
            return constant * Diff(variable, target=self.target, superscript=self.superscript)

    @property
    def arg(self):
        """Expression the derivative acts on"""
        return self.args[0]

    @property
    def target(self):
        """Subscript, usually the variable the Diff is w.r.t. """
        return self.args[1]

    @property
    def superscript(self):
        """Superscript, for example used as the Chapman-Enskog order index"""
        return self.args[2]

    def _latex(self, printer, *_):
        result = r"{\partial"
        # negative numbers (default -1) hide the superscript; symbolic superscripts are shown.
        # A plain '>= 0' comparison would raise for symbols, because a Relational has no truth value.
        if self.superscript.is_negative is not True:
            result += "^{(%s)}" % (self.superscript,)
        if self.target != -1:
            result += "_{%s}" % (printer.doprint(self.target),)

        contents = printer.doprint(self.arg)
        # atoms and nested derivatives need no parentheses
        if isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff:
            result += " " + contents
        else:
            result += " (" + contents + ") "

        result += "}"
        return result

    def __str__(self):
        return f"D({self.arg})"

    def interpolated_access(self, offset, **kwargs):
        """Represents an interpolated access on a spatially differentiated field

        Args:
            offset (Tuple[sympy.Expr]): Absolute position to determine the value of the spatial derivative
        """
        from pystencils.interpolation_astnodes import DiffInterpolatorAccess
        assert isinstance(self.arg.field, Field), "Must be field to enable interpolated accesses"
        return DiffInterpolatorAccess(self.arg.field.interpolated_access(offset, **kwargs).symbol, self.target, *offset)
123

Martin Bauer's avatar
Martin Bauer committed
124

125
class DiffOperator(sp.Expr):
    """Un-applied differential, i.e. differential operator

    Args:
        target: the differential is w.r.t. this variable.
                 This target is mainly for display purposes (it's the subscript) and to distinguish DiffOperators.
                 If the target is '-1' no subscript is displayed
        superscript: optional marker displayed as superscript;
                     it is not displayed if set to '-1' (or any other negative number)

    The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the
    DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's
    """
    is_commutative = True
    is_number = False
    is_Rational = False

    def __new__(cls, target=-1, superscript=-1):
        return sp.Expr.__new__(cls, sp.sympify(target), sp.sympify(superscript))

    @property
    def target(self):
        """Subscript, usually the variable the operator differentiates w.r.t."""
        return self.args[0]

    @property
    def superscript(self):
        """Superscript marker"""
        return self.args[1]

    def _latex(self, printer=None, *_):
        result = r"{\partial"
        # negative numbers (default -1) hide the superscript; symbolic superscripts are shown.
        # A plain '>= 0' comparison would raise for symbols, because a Relational has no truth value.
        if self.superscript.is_negative is not True:
            result += "^{(%s)}" % (self.superscript,)
        if self.target != -1:
            # use the latex printer when available, consistent with Diff._latex
            target_text = printer.doprint(self.target) if printer is not None else str(self.target)
            result += "_{%s}" % (target_text,)
        result += "}"
        return result

    @staticmethod
    def apply(expr, argument, apply_to_constants=True):
        """
        Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node.
        Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
        i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), x)

        Args:
            expr: expression containing DiffOperators
            argument: expression the operators are applied to
            apply_to_constants: if True, terms without a DiffOperator are multiplied by 'argument',
                                otherwise they are returned unchanged
        """

        def handle_mul(mul):
            args = normalize_product(mul)
            diffs = [a for a in args if isinstance(a, DiffOperator)]
            if len(diffs) == 0:
                return mul * argument if apply_to_constants else mul
            rest = [a for a in args if not isinstance(a, DiffOperator)]
            diffs.sort(key=_default_diff_sort_key)
            result = argument
            for d in reversed(diffs):
                result = Diff(result, target=d.target, superscript=d.superscript)
            return prod(rest) * result

        expr = expr.expand()
        if expr.func == sp.Add:
            return expr.func(*[handle_mul(a) for a in expr.args])
        # Mul, Pow, a bare DiffOperator or a constant term: handle_mul covers all of them
        # (a bare DiffOperator previously fell through to 'expr * argument' and was never applied)
        return handle_mul(expr)
189

Martin Bauer's avatar
Martin Bauer committed
190

Martin Bauer's avatar
Martin Bauer committed
191
192
# ----------------------------------------------------------------------------------------------------------------------

193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def diff(expr, *args):
    """Shortcut function to create nested derivatives

    The first index becomes the outermost derivative, the last index the innermost one.
    Without indices the expression is returned unchanged.

    >>> f = sp.Symbol("f")
    >>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0)
    True
    """
    nested = expr
    for target_index in args[::-1]:
        nested = Diff(nested, target_index)
    return nested


def diff_args(expr):
    """Extracts the indices and argument of possibly nested derivative - inverse of diff function

    Returns a tuple ``(innermost_argument, outer_target, ..., inner_target)``; a non-Diff expression
    yields the one-element tuple ``(expr,)``.

    >>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1)
    >>> e = diff(*args)
    >>> assert diff_args(e) == args
    """
    targets = []
    node = expr
    while isinstance(node, Diff):
        targets.append(node.args[1])
        node = node.args[0]
    return (node, *targets)


223
224
225
def diff_terms(expr):
    """Returns set of all derivatives in an expression.

    This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
    since this function only returns the outer derivatives: the search does not descend into a Diff node.

    Example:
        >>> x, y = sp.symbols("x, y")
        >>> diff_terms( diff(x, 0, 0) )
        {Diff(Diff(x, 0, -1), 0, -1)}
        >>> diff_terms( diff(x, 0, 0) + y )
        {Diff(Diff(x, 0, -1), 0, -1)}
    """
    outer_diffs = set()
    to_inspect = [expr]
    while to_inspect:
        node = to_inspect.pop()
        if isinstance(node, Diff):
            outer_diffs.add(node)
        else:
            to_inspect.extend(node.args)
    return outer_diffs


249
def collect_diffs(expr):
    """Rewrites expression into a sum of distinct derivatives with pre-factors.

    The outermost Diff nodes (see diff_terms) are used as the collection keys for sympy's collect.
    """
    outer_derivatives = diff_terms(expr)
    return expr.collect(outer_derivatives)
Martin Bauer's avatar
Martin Bauer committed
252
253


254
def zero_diffs(expr, label):
    """Replaces all differentials with the given target by 0

    Diff nodes whose target differs from 'label' are kept, but their contents are still processed,
    so a nested Diff with the given target is zeroed as well.

    Example:
        >>> x, y, f = sp.symbols("x y f")
        >>> expression = Diff(f, x) + Diff(f, y) + Diff(Diff(f, y), x) + 7
        >>> zero_diffs(expression, x)
        Diff(f, y, -1) + 7
    """

    def zeroed(node):
        if isinstance(node, Diff) and node.target == label:
            return 0
        if not node.args:
            return node
        return node.func(*(zeroed(child) for child in node.args))

    return zeroed(expr)


def evaluate_diffs(expr, var=None):
    """Replaces pystencils diff objects by sympy diff objects and evaluates them.

    Replaces Diff nodes by sp.diff. The free variable is the target of each individual Diff node if var=None,
    otherwise the specified var is used for every Diff node.

    Args:
        expr: expression containing Diff nodes
        var: optional variable used for all derivatives instead of the Diff targets
    """
    if isinstance(expr, Diff):
        # resolve the variable per node without overwriting 'var': inner Diffs of a nested derivative
        # must keep their own targets when var is None
        diff_var = expr.target if var is None else var
        return sp.diff(evaluate_diffs(expr.arg, var), diff_var)
    new_args = [evaluate_diffs(arg, var) for arg in expr.args]
    return expr.func(*new_args) if new_args else expr


def normalize_diff_order(expression, functions=None, constants=None, sort_key=_default_diff_sort_key):
    """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
    by the sorting key 'sort_key' such that the derivative terms can be further simplified

    The expression is first linearly expanded (see expand_diff_linear, which receives 'functions' and 'constants').
    """

    def reorder(node):
        if not isinstance(node, Diff):
            children = [reorder(child) for child in node.args]
            return node.func(*children) if children else node

        # collect the chain of directly nested Diff nodes, outermost first
        chain = []
        current = node
        while isinstance(current, Diff):
            chain.append(current)
            current = current.arg

        rebuilt = reorder(current)
        for d in reversed(sorted(chain, key=sort_key)):
            rebuilt = Diff(rebuilt, target=d.target, superscript=d.superscript)
        return rebuilt

    linearized = expand_diff_linear(expression.expand(), functions, constants).expand()
    return reorder(linearized)
Martin Bauer's avatar
Martin Bauer committed
312
313


314
def expand_diff_full(expr, functions=None, constants=None):
    """Fully expands derivatives: applies linearity and the product rule and normalizes nested Diff order.

    Args:
        expr: expression containing Diff nodes; an sp.Matrix is processed element-wise
        functions: symbols that are considered functions (not pulled out of derivatives).
                   If None, all symbols of 'expr' except 'constants' are treated as functions
        constants: symbols considered constant; only used when 'functions' is None
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
        if constants is not None:
            functions.difference_update(constants)

    def visit(e):
        if not isinstance(e, sp.Tuple):
            e = e.expand()

        if e.func == Diff:
            result = 0
            diff_kwargs = {'target': e.target, 'superscript': e.superscript}
            diff_inner = visit(e.args[0])
            if diff_inner.func not in (sp.Add, sp.Mul):
                return e
            for term in diff_inner.args if diff_inner.func == sp.Add else [diff_inner]:
                # split each summand into a constant pre-factor and the factors the derivative acts on
                independent_terms = 1
                dependent_terms = []
                for factor in normalize_product(term):
                    if factor in functions or isinstance(factor, Diff):
                        dependent_terms.append(factor)
                    else:
                        independent_terms *= factor
                # product rule over all dependent factors
                for i, dependent_term in enumerate(dependent_terms):
                    other_dependent_terms = dependent_terms[:i] + dependent_terms[i + 1:]
                    processed_diff = normalize_diff_order(Diff(dependent_term, **diff_kwargs))
                    result += independent_terms * prod(other_dependent_terms) * processed_diff
            return result
        elif isinstance(e, sp.Piecewise):
            return sp.Piecewise(*((visit(a), b) for a, b in e.args))
        elif isinstance(e, sp.Tuple):
            # must test the current node 'e', not the outer 'expr': otherwise every sub-node of a
            # top-level Tuple would be rebuilt as a Tuple
            return sp.Tuple(*[visit(arg) for arg in e.args])
        else:
            new_args = [visit(arg) for arg in e.args]
            return e.func(*new_args) if new_args else e

    if isinstance(expr, sp.Matrix):
        return expr.applyfunc(visit)
    else:
        return visit(expr)


360
361
def expand_diff_linear(expr, functions=None, constants=None):
    """Expands all derivative nodes by applying Diff.split_linear

    Sums inside a derivative are split into separate derivatives, and constant factors are pulled in front.

    Args:
        expr: expression containing derivatives
        functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
                   if None, all symbols (except 'constants') are viewed as functions
        constants: sequence of symbols which are considered constants and can be pulled before the derivative
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
        if constants is not None:
            functions.difference_update(constants)

    def recurse(sub_expr):
        return expand_diff_linear(sub_expr, functions)

    if isinstance(expr, Diff):
        diff_kwargs = dict(target=expr.target, superscript=expr.superscript)
        inner = recurse(expr.arg)
        # the recursion may return a plain int 0, which has no 'func'
        if hasattr(inner, 'func') and inner.func == sp.Add:
            return sum(Diff(summand, **diff_kwargs).split_linear(functions) for summand in inner.args)
        single = Diff(inner, **diff_kwargs)
        return 0 if single == 0 else single.split_linear(functions)

    if isinstance(expr, sp.Piecewise):
        return sp.Piecewise(*((expand_diff_linear(piece, functions, constants), condition)
                              for piece, condition in expr.args))

    if isinstance(expr, sp.Tuple):
        return sp.Tuple(*[recurse(item) for item in expr.args])

    children = [recurse(item) for item in expr.args]
    return sp.expand(expr.func(*children) if children else expr)
Martin Bauer's avatar
Martin Bauer committed
396
397


398
def _apply_product_rule(arg, target, superscript):
    """Differentiates an already expanded 'arg' w.r.t. (target, superscript), splitting sums and applying the
    product rule to every product term."""
    if arg.func == sp.Add:
        # each summand may itself be a product, so apply the rule per term (not just wrap it in a Diff)
        return sp.Add(*[_apply_product_rule(term, target, superscript) for term in arg.args])
    if arg.func not in (sp.Mul, sp.Pow):
        return Diff(arg, target=target, superscript=superscript)
    factors = normalize_product(arg)
    result = 0
    for i, factor in enumerate(factors):
        pre_factor = prod(factors[:i] + factors[i + 1:])
        result += pre_factor * Diff(factor, target=target, superscript=superscript)
    return result


def expand_diff_products(expr):
    """Fully expands all derivatives by applying product rule

    Sums inside a derivative are split, and every resulting product term is expanded with the product rule,
    e.g. Diff(a*b + c) -> a*Diff(b) + b*Diff(a) + Diff(c). Nested derivatives are expanded inside-out.
    """
    if isinstance(expr, Diff):
        arg = expand_diff_products(expr.arg)
        return _apply_product_rule(arg, expr.target, expr.superscript)
    new_args = [expand_diff_products(e) for e in expr.args]
    return expr.func(*new_args) if new_args else expr
Martin Bauer's avatar
Martin Bauer committed
418
419


420
def combine_diff_products(expr):
    """Inverse product rule: merges terms like ``a * Diff(b) + b * Diff(a)`` into ``Diff(a * b)``.

    Terms are grouped by the (target, superscript) of their Diff factors. Pairs of entries are merged greedily
    when the merge does not increase the operation count of the pre-factors; anything that cannot be merged is
    kept as is.
    """
    DiffInfo = namedtuple("DiffInfo", ["target", "superscript"])

    class DiffSplit:
        """Mutable pair (pre_factor, argument) representing the term pre_factor * Diff(argument)."""

        def __init__(self, fac, argument):
            self.pre_factor = fac
            self.argument = argument

        def __repr__(self):
            return str((self.pre_factor, self.argument))

    def expr_to_diff_decomposition(expression):
        """Decomposes a sp.Add node containing Diff nodes into:

        diff_dict: maps (target, superscript) -> [ (pre_factor, argument), ... ]
        i.e.  a partial(b) ( a is pre-factor, b is argument)
            in case of partial(a) partial(b) two entries are created  (0.5 partial(a), b), (0.5 partial(b), a)
        rest: sum of all terms that contain no Diff factor

        Returns:
            tuple (diff_dict, rest)
        """
        assert isinstance(expression, sp.Add)
        diff_dict = defaultdict(list)
        rest = 0
        for term in expression.args:
            if isinstance(term, Diff):
                diff_dict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg))
            else:
                mul_args = normalize_product(term)
                diffs = [d for d in mul_args if isinstance(d, Diff)]
                factor = prod(d for d in mul_args if not isinstance(d, Diff))
                if not diffs:
                    rest += factor
                else:
                    # a product of several Diffs is distributed evenly over one entry per Diff
                    for i, current_diff in enumerate(diffs):
                        all_but_current = [d for j, d in enumerate(diffs) if i != j]
                        pre_factor = factor * prod(all_but_current) * sp.Rational(1, len(diffs))
                        diff_dict[DiffInfo(current_diff.target, current_diff.superscript)].append(
                            DiffSplit(pre_factor, current_diff.arg))

        return diff_dict, rest

    def match_diff_splits(own, other):
        """Returns the pre-factor remaining for 'own' after merging with 'other', or None if the merge would
        make the pre-factors more complex (measured by sp.count_ops)."""
        own_fac = own.pre_factor / other.argument
        other_fac = other.pre_factor / own.argument
        count = sp.count_ops
        if count(own_fac) > count(own.pre_factor) or count(other_fac) > count(other.pre_factor):
            return None

        new_other_factor = own_fac - other_fac
        return new_other_factor

    def process_diff_list(diff_list, label, superscript):
        """Greedily merges the first entry of 'diff_list' with its best partner, then recurses on the remainder.
        Note: mutates 'diff_list' and the contained DiffSplit objects."""
        if len(diff_list) == 0:
            return 0
        elif len(diff_list) == 1:
            return diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)

        result = 0
        matches = []
        for i, candidate in enumerate(diff_list[1:], start=1):
            match_result = match_diff_splits(candidate, diff_list[0])
            if match_result is not None:
                matches.append((i, match_result))

        if len(matches) == 0:
            result += diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)
        else:
            # first match with the simplest remaining pre-factor
            other_idx, match_result = min(matches, key=lambda e: sp.count_ops(e[1]))
            new_argument = diff_list[0].argument * diff_list[other_idx].argument
            result += (diff_list[0].pre_factor / diff_list[other_idx].argument) * Diff(new_argument, label, superscript)
            if match_result == 0:
                del diff_list[other_idx]
            else:
                diff_list[other_idx].pre_factor = match_result * diff_list[0].argument
        result += process_diff_list(diff_list[1:], label, superscript)
        return result

    def combine(expression):
        expression = expression.expand()
        if isinstance(expression, sp.Add):
            diff_dict, rest = expr_to_diff_decomposition(expression)
            for (label, superscript), diff_list in diff_dict.items():
                rest += process_diff_list(diff_list, label, superscript)
            return rest
        else:
            new_args = [combine(e) for e in expression.args]
            return expression.func(*new_args) if new_args else expression

    return combine(expr)
Martin Bauer's avatar
Martin Bauer committed
507

Martin Bauer's avatar
Martin Bauer committed
508

509
510
511
512
513
514
def replace_generic_laplacian(expr, dim=None):
    """Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions.

    This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
    For this to work, the argument of the derivative has to be a field access such that the spatial
    dimension can be determined, or 'dim' has to be given. Other Diff nodes are returned unchanged and do not
    require a dimension.

    >>> l = Diff(Diff(sp.symbols('x')))
    >>> replace_generic_laplacian(l, 3)
    Diff(Diff(x, 0, -1), 0, -1) + Diff(Diff(x, 1, -1), 1, -1) + Diff(Diff(x, 2, -1), 2, -1)

    Args:
        expr: expression possibly containing generic Laplacians (two nested Diffs, both with target -1)
        dim: spatial dimension used when it cannot be taken from a field access
    """
    if isinstance(expr, Diff):
        arg, *indices = diff_args(expr)
        is_generic_laplacian = len(indices) == 2 and all(i == -1 for i in indices)
        if not is_generic_laplacian:
            return expr
        # the dimension is only needed when a replacement actually happens
        if isinstance(arg, Field.Access):
            dim = arg.field.spatial_dimensions
        assert dim is not None, "Spatial dimension unknown: pass 'dim' or use field accesses as arguments"
        return sum(diff(arg, i, i) for i in range(dim))
    new_args = [replace_generic_laplacian(a, dim) for a in expr.args]
    return expr.func(*new_args) if new_args else expr


Martin Bauer's avatar
Martin Bauer committed
535
def functional_derivative(functional, v):
    r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation

    .. math ::

        \frac{\delta F}{\delta v} =
                \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v}

    - assumes that gradients are represented by Diff() node
    - Diff(Diff(r)) represents the divergence of r
    - only derivatives acting directly on v (Diff(v, ...)) contribute to the gradient term

    Args:
        functional: expression of the integrand F
        v: symbol (or field access) with respect to which the functional derivative is taken

    Returns:
        expression for the functional derivative
    """
    diffs = functional.atoms(Diff)
    # replace all Diff nodes by dummies so that sp.diff treats them as independent of v
    bulk_substitutions = {d: sp.Dummy() for d in diffs}
    bulk_substitutions_inverse = {dummy: d for d, dummy in bulk_substitutions.items()}
    non_diff_part = functional.subs(bulk_substitutions)
    partial_f_partial_v = sp.diff(non_diff_part, v).subs(bulk_substitutions_inverse)

    gradient_part = 0
    for diff_obj in diffs:
        if diff_obj.arg != v:
            continue
        dummy = sp.Dummy()
        partial_f_partial_grad_v = functional.subs(diff_obj, dummy).diff(dummy).subs(dummy, diff_obj)
        gradient_part += Diff(partial_f_partial_grad_v, target=diff_obj.target, superscript=diff_obj.superscript)

    return partial_f_partial_v - gradient_part