diff --git a/sympyextensions.py b/sympyextensions.py index 79fe024487b775b91d2e00bb511c607305e1939b..d14630b9dca638915d43d8363b88facddc3e52ff 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -97,6 +97,37 @@ def fastSubs(term, subsDict): return visit(term) +def fastSubsWithNormalize(term, subsDict, normalizeFunc): + def visit(expr): + if expr in subsDict: + return subsDict[expr], True + if not hasattr(expr, 'args'): + return expr, False + + paramList = [] + substituted = False + for a in expr.args: + replacedExpr, s = visit(a) + paramList.append(replacedExpr) + if s: + substituted = True + + if not paramList: + return expr, False + else: + if substituted: + result, _ = visit(normalizeFunc(expr.func(*paramList))) + return result, True + else: + return expr.func(*paramList), False + + if len(subsDict) == 0: + return term + else: + res, _ = visit(term) + return res + + def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None): """ Transformation for replacing a given subexpression inside a sum