Skip to content
Snippets Groups Projects
Commit 0a29a147 authored by Martin Bauer's avatar Martin Bauer
Browse files

boundary generatlization

parent d6d843fb
No related merge requests found
......@@ -226,6 +226,21 @@ class EquationCollection:
allLhs = [eq.lhs for eq in self.mainEquations]
return self.extract(allLhs)
def insertSubexpression(self, symbol):
newSubexpressions = []
subsDict = None
for se in self.subexpressions:
if se.lhs == symbol:
subsDict = {se.lhs: se.rhs}
else:
newSubexpressions.append(se)
if subsDict is None:
return self
newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
return self.copy(newEqs, newSubexpressions)
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
"""Returns a new equation collection by inserting all subexpressions into the main equations"""
if len(self.subexpressions) == 0:
......
......@@ -25,13 +25,13 @@ def sympyCSE(equationCollection):
def applyOnAllEquations(equationCollection, operation):
"""Applies sympy expand operation to all equations in collection"""
result = [operation(s) for s in equationCollection.mainEquations]
result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations]
return equationCollection.copy(result)
def applyOnAllSubexpressions(equationCollection, operation):
return equationCollection.copy(equationCollection.mainEquations,
[operation(s) for s in equationCollection.subexpressions])
result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions]
return equationCollection.copy(equationCollection.mainEquations, result)
def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
......@@ -60,6 +60,8 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
def addSubexpressionsForDivisions(equationCollection):
"""Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
divisors = set()
def searchDivisors(term):
......
......@@ -260,12 +260,33 @@ def extractMostCommonFactor(term):
coeffDict = term.as_coefficients_dict()
counter = Counter([Abs(v) for v in coeffDict.values()])
commonFactor, occurances = max(counter.items(), key=operator.itemgetter(1))
if occurances == 1 and (1 in counter):
commonFactor, occurrences = max(counter.items(), key=operator.itemgetter(1))
if occurrences == 1 and (1 in counter):
commonFactor = 1
return commonFactor, term / commonFactor
def mostCommonTermFactorization(term):
commonFactor, term = extractMostCommonFactor(term)
factorization = sp.factor(term)
if factorization.is_Mul:
symbolsInFactorization = []
constantsInFactorization = 1
for arg in factorization.args:
if len(arg.atoms(sp.Symbol)) == 0:
constantsInFactorization *= arg
else:
symbolsInFactorization.append(arg)
if len(symbolsInFactorization) <= 1:
return sp.Mul(commonFactor, term, evaluate=False)
else:
return sp.Mul(commonFactor, *symbolsInFactorization[:-1],
constantsInFactorization * symbolsInFactorization[-1])
else:
return sp.Mul(commonFactor, term, evaluate=False)
def countNumberOfOperations(term):
"""
Counts the number of additions, multiplications and division
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment