diff --git a/equationcollection/__init__.py b/equationcollection/__init__.py index 295aa35bbd24c1861bec5a56cc7a60896d8fe08f..434f02e213073ec556c9d70b5ed9091283533347 100644 --- a/equationcollection/__init__.py +++ b/equationcollection/__init__.py @@ -1 +1,2 @@ from pystencils.equationcollection.equationcollection import EquationCollection +from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py index fc29cbf982306e9bf71c00b5d07304b4dd78756e..b3c9cc33e566b66d380083b3dba9434614a1ff20 100644 --- a/equationcollection/equationcollection.py +++ b/equationcollection/equationcollection.py @@ -1,5 +1,5 @@ import sympy as sp -from pystencils.transformations import fastSubs +from pystencils.sympyextensions import fastSubs, countNumberOfOperations class EquationCollection: @@ -79,17 +79,33 @@ class EquationCollection: """All symbols that occur as left-hand-sides of the main equations""" return set([eq.lhs for eq in self.mainEquations]) + @property + def operationCount(self): + """See :func:`countNumberOfOperations` """ + return countNumberOfOperations(self.allEquations) + + def get(self, symbols, fromMainEquationsOnly=False): + """Return the equations which have symbols as left hand sides""" + if not hasattr(symbols, "__len__"): + symbols = list(symbols) + symbols = set(symbols) + + if not fromMainEquationsOnly: + eqsToSearchIn = self.allEquations + else: + eqsToSearchIn = self.mainEquations + + return [eq for eq in eqsToSearchIn if eq.lhs in symbols] + # ----------------------------------------- Display and Printing ------------------------------------------------- def _repr_html_(self): def makeHtmlEquationTable(equations): noBorder = 'style="border:none"' htmlTable = '<table style="border:none; width: 100%; ">' - line = '<tr {nb}> <td {nb}>${lhs}$</td> <td {nb}>$=$</td> ' \ - '<td style="border:none; width: 100%;">${rhs}$</td> </tr>' + line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' for eq in equations: - formatDict = {'lhs': sp.latex(eq.lhs), - 'rhs': sp.latex(eq.rhs), + formatDict = {'eq': sp.latex(eq), 'nb': noBorder, } htmlTable += line.format(**formatDict) htmlTable += "</table>" @@ -97,15 +113,24 @@ class EquationCollection: result = "" if len(self.subexpressions) > 0: - result += "<div>Subexpressions:<div>" + result += "<div>Subexpressions:</div>" result += makeHtmlEquationTable(self.subexpressions) - result += "<div>Main Equations:<div>" + result += "<div>Main Equations:</div>" result += makeHtmlEquationTable(self.mainEquations) return result def __repr__(self): return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations]) + def __str__(self): + result = "Subexpressions\n" + for eq in self.subexpressions: + result += str(eq) + "\n" + result += "Main Equations\n" + for eq in self.mainEquations: + result += str(eq) + "\n" + return result + # ------------------------------------- Manipulation ------------------------------------------------------------ def merge(self, other): @@ -194,11 +219,7 @@ class EquationCollection: :param fixedSymbols: dictionary with substitutions, that are applied before lambdification """ eqs = self.createNewWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations - print('abc') - for eq in eqs: - print(eq) - sp.lambdify(eq.rhs, symbols, module) - lambdas = {eq.lhs: sp.lambdify(eq.rhs, symbols, module) for eq in eqs} + lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs} def f(*args, **kwargs): return {s: f(*args, **kwargs) for s, f in lambdas.items()} diff --git a/equationcollection/simplificationstrategy.py b/equationcollection/simplificationstrategy.py new file mode 100644 index 0000000000000000000000000000000000000000..0caa0f269a60088609a69e73d2ade0639fb0b361 --- /dev/null +++ b/equationcollection/simplificationstrategy.py @@ -0,0 +1,134 @@ +import sympy as sp +import textwrap +from collections import namedtuple + + +class SimplificationStrategy: + """ + A simplification strategy is an ordered collection of simplification rules. + Each simplification is a function taking an equation collection, and returning a new simplified + equation collection. The strategy can nicely print intermediate simplification stages and results + to Jupyter notebooks. + """ + + def __init__(self): + self._rules = [] + + def addSimplificationRule(self, rule): + """ + Adds the given simplification rule to the end of the collection. + :param rule: function that taking one equation collection and returning a (simplified) equation collection + """ + self._rules.append(rule) + + @property + def rules(self): + return self._rules + + def apply(self, updateRule): + """Applies all simplification rules to the given equation collection""" + for t in self._rules: + updateRule = t(updateRule) + return updateRule + + def __call__(self, equationCollection): + """Same as apply""" + return self.apply(equationCollection) + + def createSimplificationReport(self, equationCollection): + """ + Returns a simplification report containing the number of operations at each simplification stage, together + with the run-time the simplification took. + """ + + ReportElement = namedtuple('ReportElement', ['simplificationName', 'adds', 'muls', 'divs', 'runtime']) + + class Report: + def __init__(self): + self.elements = [] + + def add(self, element): + self.elements.append(element) + + def __str__(self): + try: + import tabulate + return tabulate(self.elements, headers=['Name', 'Adds', 'Muls', 'Divs', 'Runtime']) + except ImportError: + result = "Name, Adds, Muls, Divs, Runtime\n" + for e in self.elements: + result += ",".join(e) + "\n" + return result + + def _repr_html_(self): + htmlTable = '<table style="border:none">' + htmlTable += "<tr> <th>Name</th> <th>Adds</th> <th>Muls</th> <th>Divs</th> <th>Runtime</th></tr>" + line = "<tr><td>{simplificationName}</td>" \ + "<td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{runtime}</td> </tr>" + + for e in self.elements: + htmlTable += line.format(**e._asdict()) + htmlTable += "</table>" + return htmlTable + + import time + report = Report() + op = equationCollection.operationCount + report.add(ReportElement("OriginalTerm", op['adds'], op['muls'], op['divs'], '-')) + for t in self._rules: + startTime = time.perf_counter() + equationCollection = t(equationCollection) + endTime = time.perf_counter() + op = equationCollection.operationCount + timeStr = "%.2f ms" % ((endTime - startTime) * 1000,) + report.add(ReportElement(t.__name__, op['adds'], op['muls'], op['divs'], timeStr)) + return report + + def showIntermediateResults(self, equationCollection, symbols=None): + + class IntermediateResults: + def __init__(self, strategy, eqColl, resSyms): + self.strategy = strategy + self.equationCollection = eqColl + self.restrictSymbols = resSyms + + def __str__(self): + def printEqCollection(title, eqColl): + text = title + if self.restrictSymbols: + text += "\n".join([str(e) for e in eqColl.get(self.restrictSymbols)]) + else: + text += textwrap.indent(str(eqColl), " " * 3) + return text + + result = printEqCollection("Initial Version", self.equationCollection) + eqColl = self.equationCollection + for rule in self.strategy.rules: + eqColl = rule(eqColl) + result += printEqCollection(rule.__name__, eqColl) + return result + + def _repr_html_(self): + def printEqCollection(title, eqColl): + text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, ) + if self.restrictSymbols: + text += "\n".join(["$$" + sp.latex(e) + '$$' for e in eqColl.get(self.restrictSymbols)]) + else: + text += eqColl._repr_html_() + text += "</div>" + return text + + result = printEqCollection("Initial Version", self.equationCollection) + eqColl = self.equationCollection + for rule in self.strategy.rules: + eqColl = rule(eqColl) + result += printEqCollection(rule.__name__, eqColl) + return result + + return IntermediateResults(self, equationCollection, symbols) + + def __repr__(self): + result = "Simplification Strategy:\n" + for t in self._rules: + result += " - %s\n" % (t.__name__,) + return result