Commit 9f071fdd authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: Assignment instead of sympy.Eq

- Previously sympy.Eq was used to represent assignments. However Eq
  represents equality not assignment. This means that sometimes sympy
  "simplified" an equation like a = a  to True,
-> replaced sp.Eq by pystencils.Assignment everywhere
- renamed EquationCollection to AssignmentCollection
parent eadeaadf
...@@ -3,14 +3,13 @@ from pystencils.data_types import TypedSymbol ...@@ -3,14 +3,13 @@ from pystencils.data_types import TypedSymbol
from pystencils.slicing import makeSlice from pystencils.slicing import makeSlice
from pystencils.kernelcreation import createKernel, createIndexedKernel from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot from pystencils.display_utils import showCode, toDot
from pystencils.equationcollection import EquationCollection from pystencils.assignment_collection import AssignmentCollection
from sympy.codegen.ast import Assignment as Assign from pystencils.assignment import Assignment
__all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol', 'TypedSymbol',
'makeSlice', 'makeSlice',
'createKernel', 'createIndexedKernel', 'createKernel', 'createIndexedKernel',
'showCode', 'toDot', 'showCode', 'toDot',
'EquationCollection', 'AssignmentCollection',
'Assign'] 'Assignment']
from sympy.codegen.ast import Assignment
from sympy.printing.latex import LatexPrinter
__all__ = ['Assignment']
def print_assignment_latex(printer, expr):
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs)
return f"{printed_lhs} \leftarrow {printed_rhs}"
LatexPrinter._print_Assignment = print_assignment_latex
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.assignment_collection.simplificationstrategy import SimplificationStrategy
import sympy as sp import sympy as sp
from copy import copy from copy import copy
from pystencils.assignment import Assignment
from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically
class EquationCollection(object): class AssignmentCollection(object):
""" """
A collection of equations with subexpression definitions, also represented as equations, A collection of equations with subexpression definitions, also represented as equations,
that are used in the main equations. EquationCollections can be passed to simplification methods. that are used in the main equations. AssignmentCollection can be passed to simplification methods.
These simplification methods can change the subexpressions, but the number and These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered. left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create Additionally a dictionary of simplification hints is stored, which are set by the functions that create
equation collections to transport information to the simplification system. equation collections to transport information to the simplification system.
:ivar mainEquations: list of sympy equations :ivar mainAssignments: list of sympy equations
:ivar subexpressions: list of sympy equations defining subexpressions used in main equations :ivar subexpressions: list of sympy equations defining subexpressions used in main equations
:ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
used by the simplification system. See documentation of the simplification rules for used by the simplification system. See documentation of the simplification rules for
...@@ -22,7 +23,7 @@ class EquationCollection(object): ...@@ -22,7 +23,7 @@ class EquationCollection(object):
# ----------------------------------------- Creation --------------------------------------------------------------- # ----------------------------------------- Creation ---------------------------------------------------------------
def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None): def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
self.mainEquations = equations self.mainAssignments = equations
self.subexpressions = subExpressions self.subexpressions = subExpressions
if simplificationHints is None: if simplificationHints is None:
...@@ -39,15 +40,15 @@ class EquationCollection(object): ...@@ -39,15 +40,15 @@ class EquationCollection(object):
def mainTerms(self): def mainTerms(self):
return [] return []
def copy(self, mainEquations=None, subexpressions=None): def copy(self, mainAssignments=None, subexpressions=None):
res = copy(self) res = copy(self)
res.simplificationHints = self.simplificationHints.copy() res.simplificationHints = self.simplificationHints.copy()
res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator) res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator)
if mainEquations is not None: if mainAssignments is not None:
res.mainEquations = mainEquations res.mainAssignments = mainAssignments
else: else:
res.mainEquations = self.mainEquations.copy() res.mainAssignments = self.mainAssignments.copy()
if subexpressions is not None: if subexpressions is not None:
res.subexpressions = subexpressions res.subexpressions = subexpressions
...@@ -64,13 +65,13 @@ class EquationCollection(object): ...@@ -64,13 +65,13 @@ class EquationCollection(object):
""" """
if substituteOnLhs: if substituteOnLhs:
newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments]
else: else:
newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations] newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments]
if addSubstitutionsAsSubexpressions: if addSubstitutionsAsSubexpressions:
newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions
newSubexpressions = sortEquationsTopologically(newSubexpressions) newSubexpressions = sortEquationsTopologically(newSubexpressions)
return self.copy(newEquations, newSubexpressions) return self.copy(newEquations, newSubexpressions)
...@@ -86,7 +87,7 @@ class EquationCollection(object): ...@@ -86,7 +87,7 @@ class EquationCollection(object):
@property @property
def allEquations(self): def allEquations(self):
"""Subexpression and main equations in one sequence""" """Subexpression and main equations in one sequence"""
return self.subexpressions + self.mainEquations return self.subexpressions + self.mainAssignments
@property @property
def freeSymbols(self): def freeSymbols(self):
...@@ -100,30 +101,30 @@ class EquationCollection(object): ...@@ -100,30 +101,30 @@ class EquationCollection(object):
def boundSymbols(self): def boundSymbols(self):
"""Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined.""" """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
boundSymbolsSet = set([eq.lhs for eq in self.allEquations]) boundSymbolsSet = set([eq.lhs for eq in self.allEquations])
assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \ assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \
"Not in SSA form - same symbol assigned multiple times" "Not in SSA form - same symbol assigned multiple times"
return boundSymbolsSet return boundSymbolsSet
@property @property
def definedSymbols(self): def definedSymbols(self):
"""All symbols that occur as left-hand-sides of the main equations""" """All symbols that occur as left-hand-sides of the main equations"""
return set([eq.lhs for eq in self.mainEquations]) return set([eq.lhs for eq in self.mainAssignments])
@property @property
def operationCount(self): def operationCount(self):
"""See :func:`countNumberOfOperations` """ """See :func:`countNumberOfOperations` """
return countNumberOfOperations(self.allEquations, onlyType=None) return countNumberOfOperations(self.allEquations, onlyType=None)
def get(self, symbols, fromMainEquationsOnly=False): def get(self, symbols, frommainAssignmentsOnly=False):
"""Return the equations which have symbols as left hand sides""" """Return the equations which have symbols as left hand sides"""
if not hasattr(symbols, "__len__"): if not hasattr(symbols, "__len__"):
symbols = list(symbols) symbols = list(symbols)
symbols = set(symbols) symbols = set(symbols)
if not fromMainEquationsOnly: if not frommainAssignmentsOnly:
eqsToSearchIn = self.allEquations eqsToSearchIn = self.allEquations
else: else:
eqsToSearchIn = self.mainEquations eqsToSearchIn = self.mainAssignments
return [eq for eq in eqsToSearchIn if eq.lhs in symbols] return [eq for eq in eqsToSearchIn if eq.lhs in symbols]
...@@ -145,19 +146,19 @@ class EquationCollection(object): ...@@ -145,19 +146,19 @@ class EquationCollection(object):
if len(self.subexpressions) > 0: if len(self.subexpressions) > 0:
result += "<div>Subexpressions:</div>" result += "<div>Subexpressions:</div>"
result += makeHtmlEquationTable(self.subexpressions) result += makeHtmlEquationTable(self.subexpressions)
result += "<div>Main Equations:</div>" result += "<div>Main Assignments:</div>"
result += makeHtmlEquationTable(self.mainEquations) result += makeHtmlEquationTable(self.mainAssignments)
return result return result
def __repr__(self): def __repr__(self):
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations]) return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments])
def __str__(self): def __str__(self):
result = "Subexpressions\n" result = "Subexpressions\n"
for eq in self.subexpressions: for eq in self.subexpressions:
result += str(eq) + "\n" result += str(eq) + "\n"
result += "Main Equations\n" result += "Main Assignments\n"
for eq in self.mainEquations: for eq in self.mainAssignments:
result += str(eq) + "\n" result += str(eq) + "\n"
return result return result
...@@ -165,8 +166,8 @@ class EquationCollection(object): ...@@ -165,8 +166,8 @@ class EquationCollection(object):
def merge(self, other): def merge(self, other):
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
ownDefs = set([e.lhs for e in self.mainEquations]) ownDefs = set([e.lhs for e in self.mainAssignments])
otherDefs = set([e.lhs for e in other.mainEquations]) otherDefs = set([e.lhs for e in other.mainAssignments])
assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols" assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"
ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions} ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
...@@ -180,14 +181,14 @@ class EquationCollection(object): ...@@ -180,14 +181,14 @@ class EquationCollection(object):
else: else:
# different definition - a new name has to be introduced # different definition - a new name has to be introduced
newLhs = next(self.subexpressionSymbolNameGenerator) newLhs = next(self.subexpressionSymbolNameGenerator)
newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
processedOtherSubexpressionEquations.append(newEq) processedOtherSubexpressionEquations.append(newEq)
substitutionDict[otherSubexpressionEq.lhs] = newLhs substitutionDict[otherSubexpressionEq.lhs] = newLhs
else: else:
processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict)) processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations] processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments]
return self.copy(self.mainEquations + processedOtherMainEquations, return self.copy(self.mainAssignments + processedOthermainAssignments,
self.subexpressions + processedOtherSubexpressionEquations) self.subexpressions + processedOtherSubexpressionEquations)
def getDependentSymbols(self, symbolSequence): def getDependentSymbols(self, symbolSequence):
...@@ -226,28 +227,28 @@ class EquationCollection(object): ...@@ -226,28 +227,28 @@ class EquationCollection(object):
newEquations.append(eq) newEquations.append(eq)
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract] newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
return EquationCollection(newEquations, newSubExpr) return AssignmentCollection(newEquations, newSubExpr)
def newWithoutUnusedSubexpressions(self): def newWithoutUnusedSubexpressions(self):
"""Returns a new equation collection containing only the subexpressions that """Returns a new equation collection containing only the subexpressions that
are used/referenced in the equations""" are used/referenced in the equations"""
allLhs = [eq.lhs for eq in self.mainEquations] allLhs = [eq.lhs for eq in self.mainAssignments]
return self.extract(allLhs) return self.extract(allLhs)
def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True): def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True):
if lhs is None: if lhs is None:
lhs = sp.Dummy() lhs = sp.Dummy()
eq = sp.Eq(lhs, rhs) eq = Assignment(lhs, rhs)
self.subexpressions.append(eq) self.subexpressions.append(eq)
if topologicalSort: if topologicalSort:
self.topologicalSort(subexpressions=True, mainEquations=False) self.topologicalSort(subexpressions=True, mainAssignments=False)
return lhs return lhs
def topologicalSort(self, subexpressions=True, mainEquations=True): def topologicalSort(self, subexpressions=True, mainAssignments=True):
if subexpressions: if subexpressions:
self.subexpressions = sortEquationsTopologically(self.subexpressions) self.subexpressions = sortEquationsTopologically(self.subexpressions)
if mainEquations: if mainAssignments:
self.mainEquations = sortEquationsTopologically(self.mainEquations) self.mainAssignments = sortEquationsTopologically(self.mainAssignments)
def insertSubexpression(self, symbol): def insertSubexpression(self, symbol):
newSubexpressions = [] newSubexpressions = []
...@@ -260,8 +261,8 @@ class EquationCollection(object): ...@@ -260,8 +261,8 @@ class EquationCollection(object):
if subsDict is None: if subsDict is None:
return self return self
newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions] newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations] newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments]
return self.copy(newEqs, newSubexpressions) return self.copy(newEqs, newSubexpressions)
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()): def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
...@@ -286,7 +287,7 @@ class EquationCollection(object): ...@@ -286,7 +287,7 @@ class EquationCollection(object):
else: else:
subsDict[subExpr[i].lhs] = subExpr[i].rhs subsDict[subExpr[i].lhs] = subExpr[i].rhs
newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations] newEq = [fastSubs(eq, subsDict) for eq in self.mainAssignments]
return self.copy(newEq, keptSubexpressions) return self.copy(newEq, keptSubexpressions)
def lambdify(self, symbols, module=None, fixedSymbols={}): def lambdify(self, symbols, module=None, fixedSymbols={}):
...@@ -296,7 +297,7 @@ class EquationCollection(object): ...@@ -296,7 +297,7 @@ class EquationCollection(object):
:param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy' :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
""" """
eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments
lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs} lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
def f(*args, **kwargs): def f(*args, **kwargs):
......
import sympy as sp import sympy as sp
from pystencils.equationcollection.equationcollection import EquationCollection from pystencils import Assignment, AssignmentCollection
from pystencils.sympyextensions import replaceAdditive from pystencils.sympyextensions import replaceAdditive
def sympyCseOnEquationList(eqs): def sympyCseOnEquationList(eqs):
ec = EquationCollection(eqs, []) ec = AssignmentCollection(eqs, [])
return sympyCSE(ec).allEquations return sympyCSE(ec).allEquations
def sympyCSE(equationCollection): def sympyCSE(assignment_collection):
""" """
Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection
with the additional subexpressions found with the additional subexpressions found
""" """
symbolGen = equationCollection.subexpressionSymbolNameGenerator symbolGen = assignment_collection.subexpressionSymbolNameGenerator
replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments,
symbols=symbolGen) symbols=symbolGen)
replacementEqs = [sp.Eq(*r) for r in replacements] replacementEqs = [Assignment(*r) for r in replacements]
modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)]
modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):] modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):]
newSubexpressions = replacementEqs + modifiedSubexpressions newSubexpressions = replacementEqs + modifiedSubexpressions
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs]
return equationCollection.copy(modifiedUpdateEquations, newSubexpressions) return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions)
def applyOnAllEquations(equationCollection, operation): def applyOnAllEquations(assignment_collection, operation):
"""Applies sympy expand operation to all equations in collection""" """Applies sympy expand operation to all equations in collection"""
result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations] result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments]
return equationCollection.copy(result) return assignment_collection.copy(result)
def applyOnAllSubexpressions(equationCollection, operation): def applyOnAllSubexpressions(assignment_collection, operation):
result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions] result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions]
return equationCollection.copy(equationCollection.mainEquations, result) return assignment_collection.copy(assignment_collection.mainAssignments, result)
def subexpressionSubstitutionInExistingSubexpressions(equationCollection): def subexpressionSubstitutionInExistingSubexpressions(assignment_collection):
"""Goes through the subexpressions list and replaces the term in the following subexpressions""" """Goes through the subexpressions list and replaces the term in the following subexpressions"""
result = [] result = []
for outerCtr, s in enumerate(equationCollection.subexpressions): for outerCtr, s in enumerate(assignment_collection.subexpressions):
newRhs = s.rhs newRhs = s.rhs
for innerCtr in range(outerCtr): for innerCtr in range(outerCtr):
subExpr = equationCollection.subexpressions[innerCtr] subExpr = assignment_collection.subexpressions[innerCtr]
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
result.append(sp.Eq(s.lhs, newRhs)) result.append(Assignment(s.lhs, newRhs))
return equationCollection.copy(equationCollection.mainEquations, result) return assignment_collection.copy(assignment_collection.mainAssignments, result)
def subexpressionSubstitutionInMainEquations(equationCollection): def subexpressionSubstitutionInmainAssignments(assignment_collection):
"""Replaces already existing subexpressions in the equations of the equationCollection""" """Replaces already existing subexpressions in the equations of the assignment_collection"""
result = [] result = []
for s in equationCollection.mainEquations: for s in assignment_collection.mainAssignments:
newRhs = s.rhs newRhs = s.rhs
for subExpr in equationCollection.subexpressions: for subExpr in assignment_collection.subexpressions:
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
result.append(sp.Eq(s.lhs, newRhs)) result.append(Assignment(s.lhs, newRhs))
return equationCollection.copy(result) return assignment_collection.copy(result)
def addSubexpressionsForDivisions(equationCollection): def addSubexpressionsForDivisions(assignment_collection):
"""Introduces subexpressions for all divisions which have no constant in the denominator. """Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.""" e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
divisors = set() divisors = set()
...@@ -79,9 +79,9 @@ def addSubexpressionsForDivisions(equationCollection): ...@@ -79,9 +79,9 @@ def addSubexpressionsForDivisions(equationCollection):
for a in term.args: for a in term.args:
searchDivisors(a) searchDivisors(a)
for eq in equationCollection.allEquations: for eq in assignment_collection.allEquations:
searchDivisors(eq.rhs) searchDivisors(eq.rhs)
newSymbolGen = equationCollection.subexpressionSymbolNameGenerator newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
return equationCollection.copyWithSubstitutionsApplied(substitutions, True) return assignment_collection.copyWithSubstitutionsApplied(substitutions, True)
...@@ -30,11 +30,11 @@ class SimplificationStrategy(object): ...@@ -30,11 +30,11 @@ class SimplificationStrategy(object):
updateRule = t(updateRule) updateRule = t(updateRule)
return updateRule return updateRule
def __call__(self, equationCollection): def __call__(self, assignment_collection):
"""Same as apply""" """Same as apply"""
return self.apply(equationCollection) return self.apply(assignment_collection)
def createSimplificationReport(self, equationCollection): def createSimplificationReport(self, assignment_collection):
""" """
Returns a simplification report containing the number of operations at each simplification stage, together Returns a simplification report containing the number of operations at each simplification stage, together
with the run-time the simplification took. with the run-time the simplification took.
...@@ -72,25 +72,25 @@ class SimplificationStrategy(object): ...@@ -72,25 +72,25 @@ class SimplificationStrategy(object):
import timeit import timeit
report = Report() report = Report()
op = equationCollection.operationCount op = assignment_collection.operationCount
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
for t in self._rules: for t in self._rules:
startTime = timeit.default_timer() startTime = timeit.default_timer()
equationCollection = t(equationCollection) assignment_collection = t(assignment_collection)
endTime = timeit.default_timer() endTime = timeit.default_timer()
op = equationCollection.operationCount op = assignment_collection.operationCount
timeStr = "%.2f ms" % ((endTime - startTime) * 1000,) timeStr = "%.2f ms" % ((endTime - startTime) * 1000,)
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total))