sympyextensions.py 22.2 KB
Newer Older
1
import itertools
Martin Bauer's avatar
Martin Bauer committed
2
import operator
Martin Bauer's avatar
Martin Bauer committed
3
4
5
6
7
import warnings
from collections import Counter, defaultdict
from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union

Martin Bauer's avatar
Martin Bauer committed
8
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
9
from sympy.functions import Abs
Martin Bauer's avatar
Martin Bauer committed
10

11
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression
13

Martin Bauer's avatar
Martin Bauer committed
14
15
T = TypeVar('T')  # generic type variable used in the signatures of prod, scalar_product and fast_subs

16

Martin Bauer's avatar
Martin Bauer committed
17
def prod(seq: Iterable[T]) -> T:
    """Multiplies all elements of ``seq`` together.

    The running product starts at the integer ``1`` and factors are applied left to right,
    so an empty sequence yields ``1``.

    Args:
        seq: iterable of objects supporting ``*`` (numbers, sympy expressions, ...)

    Returns:
        the product of all elements
    """
    product = 1
    for factor in seq:
        product = product * factor
    return product


22
23
24
25
26
27
28
29
30
31
32
33
34
35
def remove_small_floats(expr, threshold):
    """Replaces every ``sp.Float`` whose absolute value is below ``threshold`` by 0.

    The expression tree is rebuilt recursively through ``expr.func(*args)``, so sympy's automatic
    simplification removes the zeroed terms.

    >>> expr = sp.sympify("x + 1e-15 * y")
    >>> remove_small_floats(expr, 1e-14)
    x
    """
    is_tiny_float = isinstance(expr, sp.Float) and sp.Abs(expr) < threshold
    if is_tiny_float:
        return 0
    cleaned_args = [remove_small_floats(child, threshold) for child in expr.args]
    if not cleaned_args:
        # leaf node (symbol, number, ...) - nothing to rebuild
        return expr
    return expr.func(*cleaned_args)


Martin Bauer's avatar
Martin Bauer committed
36
37
def is_integer_sequence(sequence: Iterable) -> bool:
    """Checks if all elements of the passed sequence can be cast to integers.

    Args:
        sequence: iterable whose elements are tested with ``int(element)``

    Returns:
        True if ``int()`` succeeds for every element (also for an empty sequence), False if any
        conversion fails or ``sequence`` is not iterable
    """
    try:
        for i in sequence:
            int(i)
        return True
    # int() signals a failed conversion with TypeError (e.g. None, symbols), ValueError
    # (e.g. "abc", nan) or OverflowError (inf) - all of them mean "not an integer sequence"
    except (TypeError, ValueError, OverflowError):
        return False


Martin Bauer's avatar
Martin Bauer committed
46
47
def scalar_product(a: Iterable[T], b: Iterable[T]) -> T:
    """Scalar (dot) product ``sum(a_i * b_i)`` of two sequences.

    Elements are paired position-wise and pairing stops at the end of the shorter input;
    empty inputs give ``0``.
    """
    pairwise_products = map(operator.mul, a, b)
    return sum(pairwise_products)


Martin Bauer's avatar
Martin Bauer committed
51
52
def kronecker_delta(*args):
    """Generalized Kronecker delta for any number of arguments.

    Returns 1 if every argument compares equal to the first one (trivially true for zero or one
    argument), otherwise 0.
    """
    if any(arg != args[0] for arg in args):
        return 0
    return 1


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def tanh_step_function_approximation(x, step_location, kind='right', steepness=0.0001):
    """Approximation of step function by a tanh function

    >>> tanh_step_function_approximation(1.2, step_location=1.0, kind='right')
    1.00000000000000
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='right')
    0
    >>> tanh_step_function_approximation(1.1, step_location=1.0, kind='left')
    0
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='left')
    1.00000000000000
    >>> tanh_step_function_approximation(0.5, step_location=(0, 1), kind='middle')
    1

    Args:
        x: position at which the approximation is evaluated (number or sympy expression)
        step_location: location of the step; for kind='middle' a pair (x1, x2) delimiting the plateau
        kind: 'right'  - approximately 0 left of the step and 1 right of it,
              'left'   - approximately 1 left of the step and 0 right of it,
              'middle' - approximately 1 between x1 and x2 and 0 outside
        steepness: divisor of ``x - step_location`` inside the tanh; smaller values give a sharper step

    Raises:
        ValueError: if kind is not one of 'left', 'right' or 'middle'
    """
    if kind == 'left':
        return (1 - sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'right':
        return (1 + sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'middle':
        x1, x2 = step_location
        # 1 minus (falling step at x1 + rising step at x2) leaves a plateau between x1 and x2
        return 1 - (tanh_step_function_approximation(x, x1, 'left', steepness)
                    + tanh_step_function_approximation(x, x2, 'right', steepness))
    # previously an unknown kind silently returned None
    raise ValueError("Unknown kind {!r} - expected 'left', 'right' or 'middle'".format(kind))
81
82


Martin Bauer's avatar
Martin Bauer committed
83
84
def multidimensional_sum(i, dim):
    """Iterates over the index tuples of an ``i``-fold nested summation.

    Every one of the ``i`` indices runs over ``range(dim)``; the tuples are produced lazily in
    lexicographic order (an ``itertools.product`` object is returned).

    Example:
        >>> list(multidimensional_sum(2, dim=3))
        [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
    """
    index_ranges = (range(dim) for _ in range(i))
    return itertools.product(*index_ranges)


def normalize_product(product: sp.Expr) -> List[sp.Expr]:
    """Returns the list of factors of a sympy expression interpreted as a product.

    Powers with a positive integer exponent are unrolled into repeated factors, e.g. the factors of
    ``x**2 * y`` are ``x, x, y`` (in the order of the Mul's ``args``).

    Returns:
        * for a Mul node: its factors ('args'), where every ``sp.Pow`` factor with positive integer exponent
          is replaced by that many copies of its base
        * for a Pow node: the repeated base if the exponent is a positive integer, otherwise ``[product]``
        * for any other node: ``[product]``
    """
    def expand_power(power):
        exponent = power.exp
        if exponent.is_integer and exponent.is_number and exponent > 0:
            return [power.base] * exponent
        return [power]

    if isinstance(product, sp.Pow):
        return expand_power(product)
    if not isinstance(product, sp.Mul):
        return [product]

    factors = []
    for factor in product.args:
        if factor.func == sp.Pow:
            factors.extend(expand_power(factor))
        else:
            factors.append(factor)
    return factors


Martin Bauer's avatar
Martin Bauer committed
124
125
126
127
128
129
130
131
132
def symmetric_product(*args, with_diagonal: bool = True) -> Iterable:
    """Like ``itertools.product``, but only yields combinations whose positions are ordered.

    A combination is produced only if the positions of the picked elements within their input
    sequences are non-decreasing (``with_diagonal=True``) or strictly increasing
    (``with_diagonal=False``) from one sequence to the next, i.e. the upper triangle of the
    product including or excluding the diagonal.

    Examples:
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c']))
        [(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')]
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False))
        [(1, 'b'), (1, 'c'), (2, 'c')]
    """
    in_order = operator.le if with_diagonal else operator.lt
    position_ranges = [range(len(seq)) for seq in args]
    for positions in itertools.product(*position_ranges):
        if all(in_order(positions[k], positions[k + 1]) for k in range(len(positions) - 1)):
            yield tuple(seq[pos] for seq, pos in zip(args, positions))


Martin Bauer's avatar
Martin Bauer committed
144
def fast_subs(expression: T, substitutions: Dict,
              skip: Optional[Callable[[sp.Expr], bool]] = None) -> T:
    """Similar to sympy subs function.

    This version is much faster for big substitution dictionaries than sympy version.
    The tree is walked top-down: a node that is a key of ``substitutions`` is replaced as a whole and
    its replacement is not visited again. Objects that provide their own ``fast_subs`` method are
    delegated to. ``sp.Matrix`` inputs are processed element-wise on a copy.

    Args:
        expression: expression where parts should be substituted
        substitutions: dict defining substitutions by mapping from old to new terms
        skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped
              expressions no substitutions are done. Honoured for matrix inputs as well.

    Returns:
        the expression with substitutions applied (unchanged if ``substitutions`` is empty)
    """
    if type(expression) is sp.Matrix:
        # element-wise on a copy; ``skip`` has to be forwarded, otherwise it is silently ignored for matrices
        return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions, skip=skip))

    def visit(expr):
        if skip and skip(expr):
            return expr
        if hasattr(expr, "fast_subs"):
            return expr.fast_subs(substitutions, skip)
        if expr in substitutions:
            return substitutions[expr]
        if not hasattr(expr, 'args'):
            return expr
        param_list = [visit(a) for a in expr.args]
        return expr if not param_list else expr.func(*param_list)

    if len(substitutions) == 0:
        return expression
    else:
        return visit(expression)

176

177
178
179
180
181
182
183
def is_constant(expr):
    """Returns True if ``expr`` contains no free symbols.

    A simple replacement for sympy's ``is_constant()``, which has problems with piecewise
    defined functions, see: https://github.com/sympy/sympy/issues/16662
    """
    return not expr.free_symbols

184

Martin Bauer's avatar
Martin Bauer committed
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
                  required_match_replacement: Optional[Union[int, float]] = 0.5,
                  required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
    """Transformation for replacing a given subexpression inside a sum.

    Examples:
        The next example demonstrates the advantage of subs_additive compared to sympy.subs:
        >>> x, y, z, k = sp.symbols("x y z k")
        >>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y)
        3*k

        Terms that don't match completely can be substituted at the cost of additional terms.
        This trade-off is managed using the required_match parameters.
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0)
        3*x + 3*y + z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5)
        3*k - 2*z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2)
        3*k - 2*z

    Args:
        expr: input expression
        replacement: expression that is inserted for subexpression (if found)
        subexpression: expression to replace
        required_match_replacement:
             * if float: the fraction (between 0 and 1) of terms of the subexpression that has to be matched
               in order to replace
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1
             * if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
        required_match_original:
             * if float: the fraction (between 0 and 1) of terms of the original addition expression that has to
               be matched (the term count used is the larger one of the original sum and the subexpression)
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1

    Returns:
        new expression with replacement

    Raises:
        ValueError: if a match parameter has an unsupported type, a float parameter is outside [0, 1]
                    or an integer parameter is not positive
    """
    def normalize_match_parameter(match_parameter, expression_length):
        """Converts a match parameter into the absolute number of terms that have to match (at least 1)."""
        if match_parameter is None:
            return 1
        elif isinstance(match_parameter, float):
            if not 0 <= match_parameter <= 1:
                raise ValueError("Float match parameter has to be in [0, 1], got {}".format(match_parameter))
            res = int(match_parameter * expression_length)
            return max(res, 1)
        elif isinstance(match_parameter, int):
            if match_parameter <= 0:
                raise ValueError("Integer match parameter has to be positive, got {}".format(match_parameter))
            return match_parameter
        raise ValueError("Invalid parameter")

    normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
    # the subexpression never changes, so its coefficients are computed only once (it is only read below)
    subexpression_coefficient_dict = subexpression.as_coefficients_dict()
    subexpression_terms = set(subexpression_coefficient_dict.keys())

    def visit(current_expr):
        if current_expr.is_Add:
            expr_max_length = max(len(current_expr.args), len(subexpression.args))
            normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length)
            required_matches = max(normalized_replacement_match, normalized_current_expr_match)
            expr_coefficients = current_expr.as_coefficients_dict()
            intersection = subexpression_terms.intersection(set(expr_coefficients))
            if len(intersection) >= required_matches:
                # find common factor: ratio of coefficients for every shared term, then take the most frequent one
                factors = defaultdict(int)
                for common_symbol in subexpression_coefficient_dict.keys():
                    if common_symbol not in expr_coefficients:
                        continue
                    factor = expr_coefficients[common_symbol] / subexpression_coefficient_dict[common_symbol]
                    factors[sp.simplify(factor)] += 1

                common_factor = max(factors.items(), key=operator.itemgetter(1))[0]
                if factors[common_factor] >= required_matches:
                    return current_expr - common_factor * subexpression + common_factor * replacement

        # if no subexpression was found
        param_list = [visit(a) for a in current_expr.args]
        if not param_list:
            return current_expr
        else:
            if current_expr.func == sp.Mul and sp.numbers.Zero() in param_list:
                return sp.numbers.Zero()
            else:
                return current_expr.func(*param_list, evaluate=False)

    return visit(expr)


Martin Bauer's avatar
Martin Bauer committed
271
272
273
274
275
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like x*y by 1/2 * ( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    Args:
        expr: input expression
        search_symbols: symbols that are searched for
                         for example, given [x,y,z] terms like x*y, x*z, z*y are replaced
        positive: there are two ways to do this substitution, either with term
                 (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                 if positive=False the second version, x*y = -1/2 * ( (x-y)**2 - x**2 - y**2 ), is done,
                 if positive=None the sign is determined by the sign of the coefficient of the mixed term
                 that is replaced (evaluated with all remaining symbols set to 1)
        replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
                       and the replacement equation is added to the list

    Returns:
        the transformed expression

    Raises:
        ValueError: if positive=None and the sign of a mixed term's coefficient cannot be determined
    """
    # Materialize once: membership is tested for every factor of every product in the tree,
    # which would exhaust a one-shot iterable such as a generator after the first test.
    search_symbols = tuple(search_symbols)
    # left-hand sides already present in replace_mixed, so every mixed symbol is only defined once
    mixed_symbols_replaced = {e.lhs for e in replace_mixed} if replace_mixed is not None else set()

    def visit(term):
        if term.is_Mul:
            distinct_search_symbols = set()
            nr_of_search_terms = 0
            # sympy one (not the int 1), so that .atoms() below also works for a bare product like x*y
            other_factors = sp.Integer(1)
            for t in term.args:
                if t in search_symbols:
                    nr_of_search_terms += 1
                    distinct_search_symbols.add(t)
                else:
                    other_factors *= t
            if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
                u, v = sorted(distinct_search_symbols, key=lambda symbol: symbol.name)
                use_plus = positive
                if use_plus is None:
                    # sign of the numeric coefficient: set all remaining symbols to 1
                    other_factors_without_symbols = other_factors
                    for s in other_factors.atoms(sp.Symbol):
                        other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
                    use_plus = other_factors_without_symbols.is_positive
                    if use_plus is None:
                        raise ValueError("Cannot determine the sign of the coefficient {} of mixed term {}*{}; "
                                         "pass positive=True or positive=False".format(other_factors, u, v))
                sign = 1 if use_plus else -1
                if replace_mixed is not None:
                    new_symbol_str = 'P' if use_plus else 'M'
                    mixed_symbol_name = u.name + new_symbol_str + v.name
                    mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                    if mixed_symbol not in mixed_symbols_replaced:
                        mixed_symbols_replaced.add(mixed_symbol)
                        replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
                else:
                    mixed_symbol = u + sign * v
                return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

        param_list = [visit(a) for a in term.args]
        return term.func(*param_list, evaluate=False) if param_list else term

    return visit(expr)


Martin Bauer's avatar
Martin Bauer committed
327
328
def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr:
    """Removes all terms that contain more than 'order' factors of given 'symbols'

    The expression is expanded first; every summand is then classified by the number of factors of
    ``symbols`` it contains, where a power ``s**n`` of a symbol counts as ``n`` factors and a bare
    symbol as one.

    Example:
        >>> x, y = sp.symbols("x y")
        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
        >>> remove_higher_order_terms(term, order=2, symbols=[x, y])
        x + y**2

    Args:
        expr: input expression
        symbols: symbols whose factors are counted
        order: maximum number of factors of ``symbols`` a term may contain to be kept

    Returns:
        the expanded expression containing only the kept terms (sympy zero if no term is kept)
    """
    expr = expr.expand()

    def velocity_factors_in_product(product):
        """Counts the factors of ``symbols`` in a single (non-sum) term."""
        factor_count = 0
        if type(product) is sp.Mul:
            for factor in product.args:
                if type(factor) is sp.Pow:
                    if factor.args[0] in symbols:
                        factor_count += factor.args[1]
                if factor in symbols:
                    factor_count += 1
        elif type(product) is sp.Pow:
            if product.args[0] in symbols:
                factor_count += product.args[1]
        elif product in symbols:
            # a bare symbol is a first order term (it used to be counted as zero and was never removed)
            factor_count += 1
        return factor_count

    if type(expr) is not sp.Add:
        # a single term: keep or drop it as a whole
        return expr if velocity_factors_in_product(expr) <= order else sp.Rational(0, 1)

    result = sp.Rational(0, 1)
    for sum_term in expr.args:
        if velocity_factors_in_product(sum_term) <= order:
            result += sum_term
    return result


Martin Bauer's avatar
Martin Bauer committed
371
372
373
def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol,
                        new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]:
    """Completes the square of an expression that is quadratic in ``symbol_to_complete``.

    With ``a`` and ``b`` the quadratic and linear coefficients, the expression is rewritten in terms
    of ``new_variable = symbol_to_complete + b/(2a)``, so that the variable only appears squared.

    Examples:
        >>> a, b, c, s, n = sp.symbols("a b c s n")
        >>> expr = a * s**2 + b * s + c
        >>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n)
        >>> completed_expr
        a*n**2 + c - b**2/(4*a)
        >>> substitution
        (n, s + b/(2*a))

    Returns:
        (completed_expr, substitution) - passing ``substitution`` to ``subs`` of the completed expression
        gives back the original expression. If ``expr`` is not a polynomial of degree two in
        ``symbol_to_complete``, ``(expr, None)`` is returned.
    """
    coefficients = sp.Poly(expr, symbol_to_complete).all_coeffs()
    if len(coefficients) != 3:
        return expr, None
    quadratic_coefficient, linear_coefficient = coefficients[:2]
    shift = linear_coefficient / (2 * quadratic_coefficient)
    shifted_expr = expr.subs(symbol_to_complete, new_variable - shift)
    return sp.simplify(shifted_expr), (new_variable, symbol_to_complete + shift)
394
395


Martin Bauer's avatar
Martin Bauer committed
396
397
398
399
400
401
def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]):
    """Completes squares in arguments of exponential which makes them simpler to integrate.

    Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function.

    Inside every ``exp`` node the square is completed for each symbol of ``symbols_to_complete`` in turn
    (see :func:`complete_the_square`). The shifted variable is then renamed back to the original symbol,
    so the result equals ``expr`` only up to that shift of variables. This is presumably intended for
    integrals over the whole real line, where the shift does not change the value.
    """
    placeholders = [sp.Dummy() for _ in symbols_to_complete]

    def transform(node):
        if node.func != sp.exp:
            children = [transform(child) for child in node.args]
            return node.func(*children) if children else node
        exponent = node.args[0]
        for symbol, placeholder in zip(symbols_to_complete, placeholders):
            exponent = complete_the_square(exponent, symbol, placeholder)[0]
        return sp.exp(sp.expand(exponent))

    transformed = transform(expr)
    for symbol, placeholder in zip(symbols_to_complete, placeholders):
        transformed = transformed.subs(placeholder, symbol)
    return transformed


Martin Bauer's avatar
Martin Bauer committed
422
def extract_most_common_factor(term):
    """Splits a sum of fractions into ``(common_factor, term / common_factor)``.

    The common factor is the absolute value of the most frequent numeric coefficient among the
    summands (as reported by ``as_coefficients_dict``). If no coefficient magnitude occurs more than
    once and 1 is among them, 1 is used as factor.
    """
    magnitudes = Counter(Abs(coefficient) for coefficient in term.as_coefficients_dict().values())
    common_factor, occurrences = max(magnitudes.items(), key=operator.itemgetter(1))
    if 1 in magnitudes and occurrences == 1:
        common_factor = 1
    return common_factor, term / common_factor
430
431


Martin Bauer's avatar
Martin Bauer committed
432
433
434
def count_operations(term: Union[sp.Expr, List[sp.Expr]],
                     only_type: Optional[str] = 'real') -> Dict[str, int]:
    """Counts the number of additions, multiplications, divisions and square roots.

    Counting rules (heuristic):
        * a sum with n terms counts as n-1 additions, a product with n factors as n-1 multiplications,
          where factors equal to 1 or -1 are not counted
        * a power with non-negative integer exponent n counts as n-1 multiplications; a negative integer
          exponent -n counts as one division plus n-1 multiplications and decrements the multiplication
          count by one (if positive), since such a power is usually a factor of an enclosing product
        * exponent 1/2 counts as a square root, exponent -1/2 as a square root plus a division
          (with the same multiplication decrement as above)
        * operations inside the base of a power are counted as well
        * of a ``sp.Piecewise`` only the values are visited, not the conditions; ``And``, ``Or`` and
          relations are not counted themselves, but their operands are visited

    Args:
        term: a sympy expression (term, assignment) or sequence of sympy objects
        only_type: 'real' or 'int' to count only operations on these types, or None for all

    Returns:
        dict with the keys 'adds', 'muls', 'divs', 'sqrts', 'fast_sqrts', 'fast_inv_sqrts' and 'fast_div'
    """
    from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division

    result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
              'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}

    if isinstance(term, Sequence):
        for element in term:
            r = count_operations(element, only_type)
            for operation_name in result.keys():
                result[operation_name] += r[operation_name]
        return result
    elif isinstance(term, Assignment):
        term = term.rhs

    if hasattr(term, 'evalf'):
        term = term.evalf()

    def check_type(e):
        """Decides whether operations on ``e`` are counted, based on its data type and ``only_type``."""
        if only_type is None:
            return True
        try:
            base_type = get_base_type(get_type_of_expression(e))
        except ValueError:
            return False
        if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
            return True
        if only_type == 'real' and (base_type.is_float()):
            return True
        else:
            return base_type == only_type

    def take_back_multiplication():
        """A reciprocal factor turns one multiplication of the enclosing product into a division."""
        # guard: a reciprocal that is not part of a counted product must not drive the count negative
        if result['muls'] > 0:
            result['muls'] -= 1

    def visit(t):
        visit_children = True
        if t.func is sp.Add:
            if check_type(t):
                result['adds'] += len(t.args) - 1
        elif t.func in [sp.Or, sp.And]:
            pass
        elif t.func is sp.Mul:
            if check_type(t):
                result['muls'] += len(t.args) - 1
                for a in t.args:
                    if a == 1 or a == -1:
                        result['muls'] -= 1
        elif isinstance(t, sp.Float) or isinstance(t, sp.Rational):
            pass
        elif isinstance(t, sp.Symbol):
            visit_children = False
        elif isinstance(t, sp.Indexed):
            visit_children = False
        elif t.is_integer:
            pass
        elif isinstance(t, cast_func):
            visit_children = False
            visit(t.args[0])
        elif t.func is fast_sqrt:
            result['fast_sqrts'] += 1
        elif t.func is fast_inv_sqrt:
            result['fast_inv_sqrts'] += 1
        elif t.func is fast_division:
            result['fast_div'] += 1
        elif t.func is sp.Pow:
            if check_type(t.args[0]):
                visit_children = False
                if t.exp.is_integer and t.exp.is_number:
                    if t.exp >= 0:
                        result['muls'] += int(t.exp) - 1
                    else:
                        take_back_multiplication()
                        result['divs'] += 1
                        result['muls'] += (-int(t.exp)) - 1
                elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
                    result['sqrts'] += 1
                elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
                    take_back_multiplication()
                    result['sqrts'] += 1
                    result['divs'] += 1
                else:
                    # the message has to be one string - further positional args are (category, stacklevel)
                    warnings.warn("Cannot handle exponent " + str(t.exp) + " of sp.Pow node")
                # the exponent is handled above, but operations inside the base (e.g. in a/(b+c)) still count
                visit(t.base)
            else:
                warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                              "counting will be inaccurate")
        elif t.func is sp.Piecewise:
            for child_term, condition in t.args:
                visit(child_term)
            visit_children = False
        elif isinstance(t, sp.Rel):
            pass
        else:
            warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")

        if visit_children:
            for a in t.args:
                visit(a)

    visit(term)
    return result
536
537


Martin Bauer's avatar
Martin Bauer committed
538
539
def count_operations_in_ast(ast) -> Dict[str, int]:
    """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`

    The tree is descended recursively through ``node.args``. For every ``SympyAssignment`` the counts of
    its right-hand side are added up; the totals are returned as a ``defaultdict(int)``.
    """
    from pystencils.astnodes import SympyAssignment
    totals = defaultdict(int)

    def accumulate(node):
        if not isinstance(node, SympyAssignment):
            for child in node.args:
                accumulate(child)
            return
        for operation, count in count_operations(node.rhs).items():
            totals[operation] += count

    accumulate(ast)
    return totals


Martin Bauer's avatar
Martin Bauer committed
555
556
def common_denominator(expr: sp.Expr) -> sp.Expr:
    """Least common multiple of the denominators of all rational atoms in ``expr``.

    Integer atoms are rationals as well (with denominator 1), so they do not change the result.
    """
    rational_atoms = expr.atoms(sp.Rational)
    return sp.lcm([atom.q for atom in rational_atoms])
559

Martin Bauer's avatar
Martin Bauer committed
560

Martin Bauer's avatar
Martin Bauer committed
561
def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
    r"""
    Returns the symmetric part of a sympy expression.

    This is the part that is even under a sign flip of all given symbols.

    Args:
        expr: sympy expression, labeled here as :math:`f`
        symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`

    Returns:
        :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1, ..) ]`
    """
    # raw docstring above: in a normal string "\f" of "\frac" is a form-feed character
    substitution_dict = {e: -e for e in symbols}
    return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
574
575


Martin Bauer's avatar
Martin Bauer committed
576
577
578
class SymbolCreator:
    """Creates sympy symbols on attribute access: ``SymbolCreator().x`` returns ``sp.Symbol('x')``.

    Implemented with ``__getattr__``, which is only consulted for attributes that do not exist
    otherwise. Special attributes such as ``__class__`` and ``__dict__`` therefore resolve normally,
    while any regular name still yields a new symbol. The previous ``__getattribute__`` turned even
    those special attributes into symbols, which confused introspection such as ``dir()``.
    """

    def __getattr__(self, name):
        return sp.Symbol(name)