Newer
Older
import operator
from functools import reduce, partial
from collections import defaultdict, Counter
from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple
from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment
"""Takes a sequence and returns the product of all elements"""
return reduce(operator.mul, seq, 1)
def remove_small_floats(expr, threshold):
"""Removes all sp.Float objects whose absolute value is smaller than threshold
>>> expr = sp.sympify("x + 1e-15 * y")
>>> remove_small_floats(expr, 1e-14)
x
"""
if isinstance(expr, sp.Float) and sp.Abs(expr) < threshold:
return 0
else:
new_args = [remove_small_floats(c, threshold) for c in expr.args]
return expr.func(*new_args) if new_args else expr
def is_integer_sequence(sequence: Iterable) -> bool:
"""Checks if all elements of the passed sequence can be cast to integers"""
return True
except TypeError:
return False
def scalar_product(a: Iterable[T], b: Iterable[T]) -> T:
"""Scalar product between two sequences."""
return sum(a_i * b_i for a_i, b_i in zip(a, b))
def kronecker_delta(*args):
"""Kronecker delta for variable number of arguments, 1 if all args are equal, otherwise 0"""
for a in args:
if a != args[0]:
return 0
return 1

Martin Bauer
committed
def tanh_step_function_approximation(x, step_location, kind='right', steepness=0.0001):
"""Approximation of step function by a tanh function
>>> tanh_step_function_approximation(1.2, step_location=1.0, kind='right')
1.00000000000000
>>> tanh_step_function_approximation(0.9, step_location=1.0, kind='right')
0
>>> tanh_step_function_approximation(1.1, step_location=1.0, kind='left')
0
>>> tanh_step_function_approximation(0.9, step_location=1.0, kind='left')
1.00000000000000
>>> tanh_step_function_approximation(0.5, step_location=(0, 1), kind='middle')
1
"""
if kind == 'left':
return (1 - sp.tanh((x - step_location) / steepness)) / 2
elif kind == 'right':
return (1 + sp.tanh((x - step_location) / steepness)) / 2
elif kind == 'middle':
x1, x2 = step_location
return 1 - (tanh_step_function_approximation(x, x1, 'left', steepness)
+ tanh_step_function_approximation(x, x2, 'right', steepness))

Martin Bauer
committed
def multidimensional_sum(i, dim):
"""Multidimensional summation
Example:
>>> list(multidimensional_sum(2, dim=3))
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
prod_args = [range(dim)] * i
return itertools.product(*prod_args)
def normalize_product(product: sp.Expr) -> List[sp.Expr]:
"""Expects a sympy expression that can be interpreted as a product and returns a list of all factors.
Removes sp.Pow nodes that have integer exponent by representing them as single factors in list.
Returns:
* for a Mul node list of factors ('args')
* for a Pow node with positive integer exponent a list of factors
* for other node types [product] is returned
if power.exp.is_integer and power.exp.is_number and power.exp > 0:
return [power.base] * power.exp
else:
return [power]
if isinstance(product, sp.Pow):
return handle_pow(product)
elif isinstance(product, sp.Mul):
result = []
for a in product.args:
if a.func == sp.Pow:
else:
result.append(a)
return result
else:
return [product]
def symmetric_product(*args, with_diagonal: bool = True) -> Iterable:
"""Similar to itertools.product but yields only values where the index is ascending i.e. values below/up to diagonal
Examples:
>>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c']))
[(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')]
>>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False))
[(1, 'b'), (1, 'c'), (2, 'c')]
"""
ranges = [range(len(a)) for a in args]
for idx in itertools.product(*ranges):
for t in range(1, len(idx)):
if (with_diagonal and idx[t - 1] > idx[t]) or (not with_diagonal and idx[t - 1] >= idx[t]):
valid_index = False
yield tuple(a[i] for a, i in zip(args, idx))
Args:
expression: expression where parts should be substituted
substitutions: dict defining substitutions by mapping from old to new terms
skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped
expressions no substitutions are done
This version is much faster for big substitution dictionaries than sympy version
"""
if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
if skip and skip(expr):
return expr
if hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions)
if expr in substitutions:
return substitutions[expr]
if not hasattr(expr, 'args'):
return expr
param_list = [visit(a) for a in expr.args]
return expr if not param_list else expr.func(*param_list)
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
required_match_replacement: Optional[Union[int, float]] = 0.5,
required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
"""Transformation for replacing a given subexpression inside a sum.
Examples:
The next example demonstrates the advantage of replace_additive compared to sympy.subs:
>>> x, y, z, k = sp.symbols("x y z k")
>>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y)
3*k
Terms that don't match completely can be substituted at the cost of additional terms.
This trade-off is managed using the required_match parameters.
>>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0)
3*x + 3*y + z
>>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5)
3*k - 2*z
>>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2)
3*k - 2*z
replacement: expression that is inserted for subexpression (if found)
subexpression: expression to replace
required_match_replacement:
* if float: the percentage of terms of the subexpression that has to be matched in order to replace
* if integer: the total number of terms that has to be matched in order to replace
* None: is equal to integer 1
* if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
required_match_original:
* if float: the percentage of terms of the original addition expression that has to be matched
* if integer: the total number of terms that has to be matched in order to replace
* None: is equal to integer 1
Returns:
new expression with replacement
def normalize_match_parameter(match_parameter, expression_length):
if match_parameter is None:
elif isinstance(match_parameter, float):
assert 0 <= match_parameter <= 1
res = int(match_parameter * expression_length)
elif isinstance(match_parameter, int):
assert match_parameter > 0
return match_parameter
raise ValueError("Invalid parameter")
normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
def visit(current_expr):
if current_expr.is_Add:
expr_max_length = max(len(current_expr.args), len(subexpression.args))
normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length)
expr_coefficients = current_expr.as_coefficients_dict()
subexpression_coefficient_dict = subexpression.as_coefficients_dict()
intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
# find common factor
factors = defaultdict(lambda: 0)
skips = 0
for common_symbol in subexpression_coefficient_dict.keys():
if common_symbol not in expr_coefficients:
factor = expr_coefficients[common_symbol] / subexpression_coefficient_dict[common_symbol]
factors[sp.simplify(factor)] += 1
common_factor = max(factors.items(), key=operator.itemgetter(1))[0]
if factors[common_factor] >= max(normalized_current_expr_match, normalized_replacement_match):
return current_expr - common_factor * subexpression + common_factor * replacement
param_list = [visit(a) for a in current_expr.args]
if not param_list:
return current_expr
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions
Args:
expr: input expression
search_symbols: symbols that are searched for
for example, given [x,y,z] terms like x*y, x*z, z*y are replaced
positive: there are two ways to do this substitution, either with term
(x+y)**2 or (x-y)**2 . if positive=True the first version is done,
if positive=False the second version is done, if positive=None the
sign is determined by the sign of the mixed term that is replaced
replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
and the replacement equation is added to the list
mixed_symbols_replaced = set([e.lhs for e in replace_mixed]) if replace_mixed is not None else set()
distinct_search_symbols = set()
nr_of_search_terms = 0
other_factors = 1
if t in search_symbols:
nr_of_search_terms += 1
distinct_search_symbols.add(t)
other_factors *= t
if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
u, v = sorted(list(distinct_search_symbols), key=lambda symbol: symbol.name)
other_factors_without_symbols = other_factors
for s in other_factors.atoms(sp.Symbol):
other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
positive = other_factors_without_symbols.is_positive
assert positive is not None
sign = 1 if positive else -1
if replace_mixed is not None:
new_symbol_str = 'P' if positive else 'M'
mixed_symbol_name = u.name + new_symbol_str + v.name
mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
if mixed_symbol not in mixed_symbols_replaced:
mixed_symbols_replaced.add(mixed_symbol)
replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
mixed_symbol = u + sign * v
return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)
param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args]
result = expr.func(*param_list, evaluate=False) if param_list else expr
def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr:
"""Removes all terms that contain more than 'order' factors of given 'symbols'
Example:
>>> x, y = sp.symbols("x y")
>>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
>>> remove_higher_order_terms(term, order=2, symbols=[x, y])
"""
from sympy.core.power import Pow
from sympy.core.add import Add, Mul
result = 0
def velocity_factors_in_product(product):
factor_count = 0
if type(product) is Mul:
for factor in product.args:
if type(factor) == Pow:
if factor.args[0] in symbols:
elif type(product) is Pow:
if product.args[0] in symbols:
if type(expr) == Mul or type(expr) == Pow:
if velocity_factors_in_product(expr) <= order:
return expr
else:
return sp.Rational(0, 1)
for sum_term in expr.args:
if velocity_factors_in_product(sum_term) <= order:
result += sum_term
def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol,
new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]:
"""Transforms second order polynomial into only squared part.
Examples:
>>> a, b, c, s, n = sp.symbols("a b c s n")
>>> expr = a * s**2 + b * s + c
>>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n)
>>> completed_expr
a*n**2 + c - b**2/(4*a)
>>> substitution
(n, s + b/(2*a))
(replaced_expr, tuple to pass to subs, such that old expr comes out again)
p = sp.Poly(expr, symbol_to_complete)
coefficients = p.all_coeffs()
if len(coefficients) != 3:
a, b, _ = coefficients
expr = expr.subs(symbol_to_complete, new_variable - b / (2 * a))
return sp.simplify(expr), (new_variable, symbol_to_complete + b / (2 * a))
def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]):
"""Completes squares in arguments of exponential which makes them simpler to integrate.
Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function
"""
dummies = [sp.Dummy() for _ in symbols_to_complete]
def visit(term):
if term.func == sp.exp:
exp_arg = term.args[0]
for symbol_to_complete, dummy in zip(symbols_to_complete, dummies):
exp_arg, substitution = complete_the_square(exp_arg, symbol_to_complete, dummy)
return sp.exp(sp.expand(exp_arg))
param_list = [visit(a) for a in term.args]
if not param_list:
for s, d in zip(symbols_to_complete, dummies):
result = result.subs(d, s)
"""Processes a sum of fractions: determines the most common factor and splits term in common factor and rest"""
coefficient_dict = term.as_coefficients_dict()
counter = Counter([Abs(v) for v in coefficient_dict.values()])
common_factor, occurrences = max(counter.items(), key=operator.itemgetter(1))
common_factor = 1
return common_factor, term / common_factor
def count_operations(term: Union[sp.Expr, List[sp.Expr]],
only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division.
Args:
term: a sympy expression (term, assignment) or sequence of sympy objects
only_type: 'real' or 'int' to count only operations on these types, or None for all
"""
result = {'adds': 0, 'muls': 0, 'divs': 0}
if isinstance(term, Sequence):
for element in term:
for operation_name in result.keys():
result[operation_name] += r[operation_name]
elif isinstance(term, Assignment):
if not hasattr(term, 'evalf'):
return result
except ValueError:
return False
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
result['adds'] += len(t.args) - 1
result['muls'] += len(t.args) - 1
for a in t.args:
if a == 1 or a == -1:
result['muls'] -= 1
elif isinstance(t, sp.Float) or isinstance(t, sp.Rational):
pass
elif isinstance(t, sp.Symbol):
elif isinstance(t, sp.tensor.Indexed):
if t.exp.is_integer and t.exp.is_number:
if t.exp >= 0:
result['muls'] += int(t.exp) - 1
else:
result['muls'] -= 1
result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1
else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate")
else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")
for a in t.args:
visit(a)
visit(term)
return result
def count_operations_in_ast(ast) -> Dict[str, int]:
"""Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
from pystencils.astnodes import SympyAssignment
result = {'adds': 0, 'muls': 0, 'divs': 0}
def visit(node):
if isinstance(node, SympyAssignment):
result['adds'] += r['adds']
result['muls'] += r['muls']
result['divs'] += r['divs']
else:
for arg in node.args:
visit(arg)
visit(ast)
return result
def common_denominator(expr: sp.Expr) -> sp.Expr:
"""Finds least common multiple of all denominators occurring in an expression"""
denominators = [r.q for r in expr.atoms(sp.Rational)]
return sp.lcm(denominators)
def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
"""
Returns the symmetric part of a sympy expressions.
Args:
expr: sympy expression, labeled here as :math:`f`
symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`
Returns:
:math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]`
substitution_dict = {e: -e for e in symbols}
return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in assignments])
return [Assignment(a, b) for a, b in res]
class SymbolCreator:
def __getattribute__(self, name):
return sp.Symbol(name)