Skip to content
Snippets Groups Projects
sympyextensions.py 21.4 KiB
Newer Older
import itertools
import warnings
Martin Bauer's avatar
Martin Bauer committed
import operator
from functools import reduce, partial
from collections import defaultdict, Counter
Martin Bauer's avatar
Martin Bauer committed
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple
Martin Bauer's avatar
Martin Bauer committed
from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
# Generic type variable: ties the element type of the arguments to the return type (e.g. in prod, scalar_product)
T = TypeVar('T')

Martin Bauer's avatar
Martin Bauer committed
def prod(seq: Iterable[T]) -> T:
    """Multiply all elements of ``seq`` together, left to right.

    The empty product is the integer ``1``.
    """
    product = 1
    for factor in seq:
        product = product * factor
    return product


def remove_small_floats(expr, threshold):
    """Replace every ``sp.Float`` with absolute value below ``threshold`` by zero.

    The expression tree is rebuilt recursively; nodes without arguments (other than small floats)
    are returned unchanged.

    >>> expr = sp.sympify("x + 1e-15 * y")
    >>> remove_small_floats(expr, 1e-14)
    x
    """
    is_small_float = isinstance(expr, sp.Float) and sp.Abs(expr) < threshold
    if is_small_float:
        return 0
    children = expr.args
    if not children:
        return expr
    cleaned_children = [remove_small_floats(child, threshold) for child in children]
    return expr.func(*cleaned_children)


Martin Bauer's avatar
Martin Bauer committed
def is_integer_sequence(sequence: Iterable) -> bool:
    """Checks if all elements of the passed sequence can be cast to integers.

    Args:
        sequence: iterable whose elements are each passed to ``int()``

    Returns:
        True if ``int()`` succeeds for every element (also for an empty sequence),
        False if any element cannot be converted or ``sequence`` is not iterable.
    """
    try:
        for element in sequence:
            int(element)
    except (TypeError, ValueError):
        # TypeError: e.g. int(None) or a non-iterable argument; ValueError: e.g. int("abc")
        return False
    return True


Martin Bauer's avatar
Martin Bauer committed
def scalar_product(a: Iterable[T], b: Iterable[T]) -> T:
    """Return the sum of the element-wise products of ``a`` and ``b``.

    Like ``zip``, pairing stops at the end of the shorter sequence; two empty inputs give ``0``.
    """
    return sum(map(operator.mul, a, b))


Martin Bauer's avatar
Martin Bauer committed
def kronecker_delta(*args):
    """Generalized Kronecker delta.

    Returns 1 if every argument compares equal to the first one (trivially true for zero or one
    argument), otherwise 0.
    """
    return 0 if any(arg != args[0] for arg in args) else 1


def tanh_step_function_approximation(x, step_location, kind='right', steepness=0.0001):
    """Approximation of step function by a tanh function

    Args:
        x: position at which the approximation is evaluated
        step_location: location of the step; for ``kind='middle'`` a pair ``(x1, x2)``
        kind: 'right' -> approximately 1 for x > step_location,
              'left' -> approximately 1 for x < step_location,
              'middle' -> approximately 1 for x1 < x < x2
        steepness: width of the transition; smaller values give a sharper step

    Raises:
        ValueError: if ``kind`` is not one of 'left', 'right', 'middle'

    >>> tanh_step_function_approximation(1.2, step_location=1.0, kind='right')
    1.00000000000000
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='right')
    0
    >>> tanh_step_function_approximation(1.1, step_location=1.0, kind='left')
    0
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='left')
    1.00000000000000
    >>> tanh_step_function_approximation(0.5, step_location=(0, 1), kind='middle')
    1
    """
    if kind == 'left':
        return (1 - sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'right':
        return (1 + sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'middle':
        x1, x2 = step_location
        # one minus the parts left of x1 and right of x2 leaves the plateau between them
        return 1 - (tanh_step_function_approximation(x, x1, 'left', steepness)
                    + tanh_step_function_approximation(x, x2, 'right', steepness))
    raise ValueError("Unknown kind '{}' - use 'left', 'right' or 'middle'".format(kind))
Martin Bauer's avatar
Martin Bauer committed
def multidimensional_sum(i, dim):
    """Multidimensional summation

    Returns an iterator over all ``i``-tuples whose entries each run over ``range(dim)``, in
    lexicographic order, i.e. the index tuples of ``i`` nested sums with ``dim`` terms each.

    Args:
        i: number of indices (length of each tuple)
        dim: number of values each index takes (0 .. dim-1)

    Example:
        >>> list(multidimensional_sum(2, dim=3))
        [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
    """
    prod_args = [range(dim)] * i
    return itertools.product(*prod_args)


def normalize_product(product: sp.Expr) -> List[sp.Expr]:
    """Expects a sympy expression that can be interpreted as a product and returns a list of all factors.

    Removes sp.Pow nodes that have a positive integer exponent by representing them as repeated
    single factors in the list; other powers are kept as a single factor.

    Returns:
        * for a Mul node list of factors ('args')
        * for a Pow node with positive integer exponent a list of factors
        * for other node types [product] is returned
    """
    def handle_pow(power):
        # only literal positive integer exponents can be unrolled into repeated factors
        if power.exp.is_integer and power.exp.is_number and power.exp > 0:
            return [power.base] * int(power.exp)
        else:
            return [power]

    if isinstance(product, sp.Pow):
        return handle_pow(product)
    elif isinstance(product, sp.Mul):
        result = []
        for a in product.args:
            if a.func == sp.Pow:
                result += handle_pow(a)
            else:
                result.append(a)
        return result
    else:
        return [product]


Martin Bauer's avatar
Martin Bauer committed
def symmetric_product(*args, with_diagonal: bool = True) -> Iterable:
    """Similar to itertools.product but yields only values where the index is ascending i.e. values below/up to diagonal

    Args:
        *args: sequences (supporting ``len`` and indexing) to combine
        with_diagonal: if True the index tuple only has to be non-decreasing (diagonal entries are
                       included), if False it has to be strictly increasing

    Examples:
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c']))
        [(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')]
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False))
        [(1, 'b'), (1, 'c'), (2, 'c')]
    """
    in_order = operator.le if with_diagonal else operator.lt
    ranges = [range(len(a)) for a in args]
    for idx in itertools.product(*ranges):
        # all() stops at the first neighbouring pair that violates the ordering
        if all(in_order(previous, current) for previous, current in zip(idx, idx[1:])):
            yield tuple(a[i] for a, i in zip(args, idx))


Martin Bauer's avatar
Martin Bauer committed
def fast_subs(expression: T, substitutions: Dict,
              skip: Optional[Callable[[sp.Expr], bool]] = None) -> T:
    """Similar to sympy subs function.

    Args:
        expression: expression where parts should be substituted
        substitutions: dict defining substitutions by mapping from old to new terms
        skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped
              expressions no substitutions are done

    Matching is structural: a subtree is replaced if it is a key of ``substitutions``, and the inserted
    replacement is not visited again. Objects that provide their own ``fast_subs`` method are delegated
    to it; ``sp.Matrix`` objects are processed element-wise.

    This version is much faster for big substitution dictionaries than sympy version
    """
    if type(expression) is sp.Matrix:
        # forward ``skip`` so that it is honored for matrix entries as well
        return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions, skip=skip))

    def visit(expr):
        if skip and skip(expr):
            return expr
        if hasattr(expr, "fast_subs"):
            return expr.fast_subs(substitutions)
        if expr in substitutions:
            return substitutions[expr]
        if not hasattr(expr, 'args'):
            return expr
        param_list = [visit(a) for a in expr.args]
        return expr if not param_list else expr.func(*param_list)

    if len(substitutions) == 0:
        return expression
    else:
        return visit(expression)

Martin Bauer's avatar
Martin Bauer committed
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
                  required_match_replacement: Optional[Union[int, float]] = 0.5,
                  required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
    """Transformation for replacing a given subexpression inside a sum.

    Examples:
        The next example demonstrates the advantage of subs_additive compared to sympy.subs:
        >>> x, y, z, k = sp.symbols("x y z k")
        >>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y)
        3*k

        Terms that don't match completely can be substituted at the cost of additional terms.
        This trade-off is managed using the required_match parameters.
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0)
        3*x + 3*y + z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5)
        3*k - 2*z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2)
        3*k - 2*z

    Args:
        expr: input expression
        replacement: expression that is inserted for subexpression (if found)
        subexpression: expression to replace
        required_match_replacement:
             * if float: the fraction (0..1) of terms of the subexpression that has to be matched in order to replace
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1
             * if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
        required_match_original:
             * if float: the fraction (0..1) of terms of the original addition expression that has to be matched
             * if integer: the total number of terms that has to be matched in order to replace
             * None: is equal to integer 1

    Returns:
        new expression with replacement
    """
    def normalize_match_parameter(match_parameter, expression_length):
        # translate the user-facing requirement into an absolute number of terms (at least 1)
        if match_parameter is None:
            return 1
        elif isinstance(match_parameter, float):
            assert 0 <= match_parameter <= 1
            res = int(match_parameter * expression_length)
            return max(res, 1)
        elif isinstance(match_parameter, int):
            assert match_parameter > 0
            return match_parameter
        raise ValueError("Invalid parameter")

    normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
    # independent of the visited node -> computed once
    subexpression_coefficient_dict = subexpression.as_coefficients_dict()

    def visit(current_expr):
        if current_expr.is_Add:
            expr_max_length = max(len(current_expr.args), len(subexpression.args))
            normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length)
            expr_coefficients = current_expr.as_coefficients_dict()
            intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
            if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
                # find common factor: count how often each coefficient ratio occurs among the shared terms
                factors = defaultdict(int)
                for common_symbol in subexpression_coefficient_dict.keys():
                    if common_symbol not in expr_coefficients:
                        continue
                    factor = expr_coefficients[common_symbol] / subexpression_coefficient_dict[common_symbol]
                    factors[sp.simplify(factor)] += 1

                common_factor = max(factors.items(), key=operator.itemgetter(1))[0]
                if factors[common_factor] >= max(normalized_current_expr_match, normalized_replacement_match):
                    # the matched sum is replaced as a whole; its children are not visited further
                    return current_expr - common_factor * subexpression + common_factor * replacement

        # if no subexpression was found
        param_list = [visit(a) for a in current_expr.args]
        if not param_list:
            return current_expr
        else:
            return current_expr.func(*param_list, evaluate=False)

    return visit(expr)
Martin Bauer's avatar
Martin Bauer committed
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like x*y by 1/2 * ( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    Args:
        expr: input expression
        search_symbols: symbols that are searched for
                         for example, given [x,y,z] terms like x*y, x*z, z*y are replaced
        positive: there are two ways to do this substitution, either with term
                 (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                 if positive=False the second version (x*y = -1/2 * ((x-y)**2 - x**2 - y**2)) is done,
                 if positive=None the sign is determined by the sign of the mixed term that is replaced
        replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
                       (named e.g. xPy / xMy, underscores removed) and the replacement equation is added to the list

    Returns:
        expression with the mixed second order products replaced

    Raises:
        AssertionError: if positive=None and the sign of a mixed term cannot be determined
    """
    mixed_symbols_replaced = {e.lhs for e in replace_mixed} if replace_mixed is not None else set()

    if expr.is_Mul:
        distinct_search_symbols = set()
        nr_of_search_terms = 0
        # sympy one (not int 1), so that .atoms() below also works for a plain product like x*y
        other_factors = sp.Integer(1)
        for t in expr.args:
            if t in search_symbols:
                nr_of_search_terms += 1
                distinct_search_symbols.add(t)
            else:
                other_factors *= t
        if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
            u, v = sorted(distinct_search_symbols, key=lambda symbol: symbol.name)
            if positive is None:
                # decide the sign from the remaining factors with all their symbols set to 1
                other_factors_without_symbols = other_factors
                for s in other_factors.atoms(sp.Symbol):
                    other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
                positive = other_factors_without_symbols.is_positive
                assert positive is not None
            sign = 1 if positive else -1
            if replace_mixed is not None:
                new_symbol_str = 'P' if positive else 'M'
                mixed_symbol_name = u.name + new_symbol_str + v.name
                mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                if mixed_symbol not in mixed_symbols_replaced:
                    mixed_symbols_replaced.add(mixed_symbol)
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            else:
                mixed_symbol = u + sign * v
            # u*v == sign/2 * ((u + sign*v)**2 - u**2 - v**2)
            return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args]
    result = expr.func(*param_list, evaluate=False) if param_list else expr
    return result
Martin Bauer's avatar
Martin Bauer committed
def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr:
    """Removes all terms that contain more than 'order' factors of given 'symbols'

    The expression is expanded first; every summand is then kept only if the total power of the
    given symbols in it is at most ``order``. A non-sum expression is treated as a single term.

    Example:
        >>> x, y = sp.symbols("x y")
        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
        >>> remove_higher_order_terms(term, order=2, symbols=[x, y])
        x + y**2
    """
    expr = expr.expand()

    def order_of_term(term):
        """Total power of ``symbols`` in a single (non-sum) term."""
        if term in symbols:
            # a bare symbol is one factor
            return 1
        factor_count = 0
        if type(term) is sp.Mul:
            for factor in term.args:
                if type(factor) is sp.Pow:
                    if factor.args[0] in symbols:
                        factor_count += factor.args[1]
                elif factor in symbols:
                    factor_count += 1
        elif type(term) is sp.Pow:
            if term.args[0] in symbols:
                factor_count += term.args[1]
        return factor_count

    if type(expr) is not sp.Add:
        # a single term is either kept or dropped as a whole
        if order_of_term(expr) <= order:
            return expr
        return sp.Integer(0)

    result = sp.Integer(0)
    for sum_term in expr.args:
        if order_of_term(sum_term) <= order:
            result += sum_term
    return result
Martin Bauer's avatar
Martin Bauer committed
def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol,
                        new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]:
    """Transforms second order polynomial into only squared part.

    Examples:
        >>> a, b, c, s, n = sp.symbols("a b c s n")
        >>> expr = a * s**2 + b * s + c
        >>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n)
        >>> completed_expr
        a*n**2 + c - b**2/(4*a)
        >>> substitution
        (n, s + b/(2*a))

    Returns:
        (replaced_expr, tuple to pass to subs, such that old expr comes out again).
        If ``expr`` is not a polynomial of degree 2 in ``symbol_to_complete``, ``(expr, None)`` is returned.
    """
    try:
        p = sp.Poly(expr, symbol_to_complete)
    except sp.PolynomialError:
        # not a polynomial in symbol_to_complete -> nothing to complete, same as the non-quadratic case
        return expr, None
    coefficients = p.all_coeffs()
    if len(coefficients) != 3:
        return expr, None
    a, b, _ = coefficients
    expr = expr.subs(symbol_to_complete, new_variable - b / (2 * a))
    return sp.simplify(expr), (new_variable, symbol_to_complete + b / (2 * a))
Martin Bauer's avatar
Martin Bauer committed
def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]):
    """Completes squares in arguments of exponential which makes them simpler to integrate.

    Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function

    The argument of every ``sp.exp`` node is completed to a square in each of ``symbols_to_complete``
    (see ``complete_the_square``). The shifted variables are renamed back to the original symbols and
    the shifts themselves are discarded, so the result is in general NOT equal to ``expr``: inside the
    exponentials each symbol appears shifted, while the rest of the expression is left as it was.
    """
    dummies = [sp.Dummy() for _ in symbols_to_complete]

    def visit(term):
        if term.func == sp.exp:
            exp_arg = term.args[0]
            for symbol_to_complete, dummy in zip(symbols_to_complete, dummies):
                # the back-substitution returned by complete_the_square is deliberately not used
                exp_arg, _ = complete_the_square(exp_arg, symbol_to_complete, dummy)
            return sp.exp(sp.expand(exp_arg))
        param_list = [visit(a) for a in term.args]
        if not param_list:
            return term
        return term.func(*param_list)

    result = visit(expr)
    # rename the dummy (shifted) variables back to the caller's symbols
    for s, d in zip(symbols_to_complete, dummies):
        result = result.subs(d, s)
    return result
Martin Bauer's avatar
Martin Bauer committed
def extract_most_common_factor(term):
    """Processes a sum of fractions: determines the most common factor and splits term in common factor and rest

    The absolute values of the coefficients of ``term`` (from ``as_coefficients_dict``) are counted and
    the most frequent one is used as common factor. If no coefficient occurs more than once and 1 is
    among them, 1 is used instead.

    Returns:
        tuple ``(common_factor, term / common_factor)``
    """
    coefficient_dict = term.as_coefficients_dict()
    counter = Counter(Abs(v) for v in coefficient_dict.values())
    # most_common keeps first-encountered order on ties, just like max() over the items
    common_factor, occurrences = counter.most_common(1)[0]
    if occurrences == 1 and 1 in counter:
        common_factor = 1
    return common_factor, term / common_factor
Martin Bauer's avatar
Martin Bauer committed
def count_operations(term: Union[sp.Expr, List[sp.Expr]],
                     only_type: Optional[str] = 'real') -> Dict[str, int]:
    """Counts the number of additions, multiplications and division.

    Args:
        term: a sympy expression (term, assignment) or sequence of sympy objects
        only_type: 'real' or 'int' to count only operations on these types, or None for all

    Returns:
        dict with 'adds', 'muls' and 'divs' keys. Objects without ``evalf`` contribute nothing;
        unsupported node types emit a warning and make the count inaccurate.
    """
    result = {'adds': 0, 'muls': 0, 'divs': 0}

    # strings are Sequences too - iterating them would recurse forever
    if isinstance(term, Sequence) and not isinstance(term, str):
        for element in term:
            r = count_operations(element, only_type)
            for operation_name in result:
                result[operation_name] += r[operation_name]
        return result
    elif isinstance(term, Assignment):
        term = term.rhs

    if not hasattr(term, 'evalf'):
        return result

    term = term.evalf()

    def check_type(e):
        """Whether operations on ``e`` are counted for the requested ``only_type``."""
        if only_type is None:
            return True
        try:
            base_type = get_base_type(get_type_of_expression(e))
        except ValueError:
            return False
        if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
            return True
        if only_type == 'real' and base_type.is_float():
            return True
        return base_type == only_type

    def visit(t):
        visit_children = True
        if t.func is sp.Add:
            if check_type(t):
                result['adds'] += len(t.args) - 1
        elif t.func in [sp.Or, sp.And]:
            pass
        elif t.func is sp.Mul:
            if check_type(t):
                result['muls'] += len(t.args) - 1
                for a in t.args:
                    # factors +-1 are sign changes, not multiplications
                    if a == 1 or a == -1:
                        result['muls'] -= 1
        elif isinstance(t, sp.Float) or isinstance(t, sp.Rational):
            pass
        elif isinstance(t, sp.Symbol):
            visit_children = False
        elif isinstance(t, sp.tensor.Indexed):
            visit_children = False
        elif t.is_integer:
            pass
        elif isinstance(t, cast_func):
            visit_children = False
            visit(t.args[0])
        elif t.func is sp.Pow:
            if check_type(t.args[0]):
                if t.exp.is_integer and t.exp.is_number:
                    if t.exp >= 0:
                        result['muls'] += int(t.exp) - 1
                    else:
                        # NOTE(review): assumes this Pow is a factor of an enclosing Mul whose counted '*'
                        # is really this division; a standalone 1/x ends up with muls == -1
                        result['muls'] -= 1
                        result['divs'] += 1
                        result['muls'] += (-int(t.exp)) - 1
                else:
                    warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                                  "counting will be inaccurate")
            # children stay visited: the base itself may contain operations, e.g. the add in (x + y)**2
        else:
            warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")

        if visit_children:
            for a in t.args:
                visit(a)

    visit(term)
    return result
Martin Bauer's avatar
Martin Bauer committed
def count_operations_in_ast(ast) -> Dict[str, int]:
    """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`

    Each ``SympyAssignment`` node contributes the operations of its right-hand side (counted with the
    default ``only_type`` of :func:`count_operations`); all other nodes are traversed via their ``args``.

    Returns:
        dict with 'adds', 'muls' and 'divs' keys
    """
    from pystencils.astnodes import SympyAssignment
    result = {'adds': 0, 'muls': 0, 'divs': 0}

    def visit(node):
        if isinstance(node, SympyAssignment):
            r = count_operations(node.rhs)
            for operation_name in result:
                result[operation_name] += r[operation_name]
        else:
            for arg in node.args:
                visit(arg)

    visit(ast)
    return result


Martin Bauer's avatar
Martin Bauer committed
def common_denominator(expr: sp.Expr) -> sp.Expr:
    """Least common multiple of the denominators of all ``sp.Rational`` atoms in ``expr``."""
    rational_atoms = expr.atoms(sp.Rational)
    return sp.lcm([number.q for number in rational_atoms])
Martin Bauer's avatar
Martin Bauer committed
def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
    r"""
    Returns the symmetric part of a sympy expression.

    Args:
        expr: sympy expression, labeled here as :math:`f`
        symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`

    Returns:
        :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1, ..) ]`
    """
    substitution_dict = {e: -e for e in symbols}
    return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
Martin Bauer's avatar
Martin Bauer committed
def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]:
    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
    # import from the defining module: ``sp.cse_main`` is not guaranteed to be exposed by the sympy namespace
    from sympy.simplify.cse_main import reps_toposort
    res = reps_toposort([[e.lhs, e.rhs] for e in assignments])
    return [Assignment(a, b) for a, b in res]
Martin Bauer's avatar
Martin Bauer committed
class SymbolCreator:
    """Creates sympy symbols by attribute access, e.g. ``SymbolCreator().rho`` is ``sp.Symbol('rho')``."""

    def __getattr__(self, name):
        # __getattr__ is only consulted when normal lookup fails, so __class__, __dict__ etc. keep working.
        # Dunder names are not turned into symbols, so protocol probes (deepcopy, pickle) behave normally.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return sp.Symbol(name)