diff --git a/sympyextensions.py b/sympyextensions.py index 634a04429a0da910141760cdd70e770d907dd6b5..fc936272feab077613945aeb5e86236e4bf04bbf 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -135,40 +135,6 @@ def fast_subs(expression: T, substitutions: Dict, return visit(expression) -def fast_subs_and_normalize(expression, substitutions: Dict[sp.Expr, sp.Expr], - normalize: Callable[[sp.Expr], sp.Expr]) -> sp.Expr: - """Similar to fast_subs, but calls a normalization function on all substituted terms to save one AST traversal.""" - - def visit(expr): - if expr in substitutions: - return substitutions[expr], True - if not hasattr(expr, 'args'): - return expr, False - - param_list = [] - substituted = False - for a in expr.args: - replaced_expr, s = visit(a) - param_list.append(replaced_expr) - if s: - substituted = True - - if not param_list: - return expr, False - else: - if substituted: - result, _ = visit(normalize(expr.func(*param_list))) - return result, True - else: - return expr.func(*param_list), False - - if len(substitutions) == 0: - return expression - else: - res, _ = visit(expression) - return res - - def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: @@ -186,6 +152,8 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, 3*x + 3*y + z >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5) 3*k - 2*z + >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2) + 3*k - 2*z Args: expr: input expression @@ -401,15 +369,6 @@ def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp. return result -def pow2mul(expr): - """Convert integer powers in an expression to Muls, like a**2 => a*a. """ - powers = list(expr.atoms(sp.Pow)) - if any(not e.is_Integer for b, e in (i.as_base_exp() for i in powers)): - raise ValueError("A power contains a non-integer exponent") - substitutions = zip(powers, (sp.Mul(*[b]*e, evaluate=False) for b, e in (i.as_base_exp() for i in powers))) - return expr.subs(substitutions) - - def extract_most_common_factor(term): """Processes a sum of fractions: determines the most common factor and splits term in common factor and rest""" coefficient_dict = term.as_coefficients_dict()