From 0a29a147e0fc931625ff8c38f19fc9373963f1ac Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 23 Jan 2017 09:52:23 +0100
Subject: [PATCH] boundary generatlization

---
 equationcollection/equationcollection.py | 15 ++++++++++++++
 equationcollection/simplifications.py    |  8 +++++---
 sympyextensions.py                       | 25 ++++++++++++++++++++++--
 3 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py
index f6cd8be31..c3a6e9294 100644
--- a/equationcollection/equationcollection.py
+++ b/equationcollection/equationcollection.py
@@ -226,6 +226,21 @@ class EquationCollection:
         allLhs = [eq.lhs for eq in self.mainEquations]
         return self.extract(allLhs)
 
+    def insertSubexpression(self, symbol):
+        newSubexpressions = []
+        subsDict = None
+        for se in self.subexpressions:
+            if se.lhs == symbol:
+                subsDict = {se.lhs: se.rhs}
+            else:
+                newSubexpressions.append(se)
+        if subsDict is None:
+            return self
+
+        newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
+        newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
+        return self.copy(newEqs, newSubexpressions)
+
     def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
         """Returns a new equation collection by inserting all subexpressions into the main equations"""
         if len(self.subexpressions) == 0:
diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py
index 328373599..ebe3cc79c 100644
--- a/equationcollection/simplifications.py
+++ b/equationcollection/simplifications.py
@@ -25,13 +25,13 @@ def sympyCSE(equationCollection):
 
 def applyOnAllEquations(equationCollection, operation):
     """Applies sympy expand operation to all equations in collection"""
-    result = [operation(s) for s in equationCollection.mainEquations]
+    result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations]
     return equationCollection.copy(result)
 
 
 def applyOnAllSubexpressions(equationCollection, operation):
-    return equationCollection.copy(equationCollection.mainEquations,
-                                   [operation(s) for s in equationCollection.subexpressions])
+    result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions]
+    return equationCollection.copy(equationCollection.mainEquations, result)
 
 
 def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
@@ -60,6 +60,8 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
 
 
 def addSubexpressionsForDivisions(equationCollection):
+    """Introduces subexpressions for all divisions which have no constant in the denominator.
+    e.g.  :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
     divisors = set()
 
     def searchDivisors(term):
diff --git a/sympyextensions.py b/sympyextensions.py
index 8daa67136..dc5a0073b 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -260,12 +260,33 @@ def extractMostCommonFactor(term):
 
     coeffDict = term.as_coefficients_dict()
     counter = Counter([Abs(v) for v in coeffDict.values()])
-    commonFactor, occurances = max(counter.items(), key=operator.itemgetter(1))
-    if occurances == 1 and (1 in counter):
+    commonFactor, occurrences = max(counter.items(), key=operator.itemgetter(1))
+    if occurrences == 1 and (1 in counter):
         commonFactor = 1
     return commonFactor, term / commonFactor
 
 
+def mostCommonTermFactorization(term):
+    commonFactor, term = extractMostCommonFactor(term)
+
+    factorization = sp.factor(term)
+    if factorization.is_Mul:
+        symbolsInFactorization = []
+        constantsInFactorization = 1
+        for arg in factorization.args:
+            if len(arg.atoms(sp.Symbol)) == 0:
+                constantsInFactorization *= arg
+            else:
+                symbolsInFactorization.append(arg)
+        if len(symbolsInFactorization) <= 1:
+            return sp.Mul(commonFactor, term, evaluate=False)
+        else:
+            return sp.Mul(commonFactor, *symbolsInFactorization[:-1],
+                          constantsInFactorization * symbolsInFactorization[-1])
+    else:
+        return sp.Mul(commonFactor, term, evaluate=False)
+
+
 def countNumberOfOperations(term):
     """
     Counts the number of additions, multiplications and division
-- 
GitLab