From cfef0f9d014dfd8a0ca203fb2dc8329704607113 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 21 Dec 2016 10:48:46 +0100
Subject: [PATCH] finished equation collection submodule of pystencils

- moved simplification strategies from lbmpy to pystencils
- better notebook display for simplification reports
- additional demo and documentation
---
 equationcollection/__init__.py               |   1 +
 equationcollection/equationcollection.py     |  45 +++++--
 equationcollection/simplificationstrategy.py | 134 +++++++++++++++++++
 3 files changed, 168 insertions(+), 12 deletions(-)
 create mode 100644 equationcollection/simplificationstrategy.py

diff --git a/equationcollection/__init__.py b/equationcollection/__init__.py
index 295aa35bb..434f02e21 100644
--- a/equationcollection/__init__.py
+++ b/equationcollection/__init__.py
@@ -1 +1,2 @@
 from pystencils.equationcollection.equationcollection import EquationCollection
+from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy
diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py
index fc29cbf98..b3c9cc33e 100644
--- a/equationcollection/equationcollection.py
+++ b/equationcollection/equationcollection.py
@@ -1,5 +1,5 @@
 import sympy as sp
-from pystencils.transformations import fastSubs
+from pystencils.sympyextensions import fastSubs, countNumberOfOperations
 
 
 class EquationCollection:
@@ -79,17 +79,33 @@ class EquationCollection:
         """All symbols that occur as left-hand-sides of the main equations"""
         return set([eq.lhs for eq in self.mainEquations])
 
+    @property
+    def operationCount(self):
+        """See :func:`countNumberOfOperations` """
+        return countNumberOfOperations(self.allEquations)
+
+    def get(self, symbols, fromMainEquationsOnly=False):
+        """Return the equations which have symbols as left hand sides"""
+        if not hasattr(symbols, "__len__"):
+            symbols = list(symbols)
+        symbols = set(symbols)
+
+        if not fromMainEquationsOnly:
+            eqsToSearchIn = self.allEquations
+        else:
+            eqsToSearchIn = self.mainEquations
+
+        return [eq for eq in eqsToSearchIn if eq.lhs in symbols]
+
     # ----------------------------------------- Display and Printing   -------------------------------------------------
 
     def _repr_html_(self):
         def makeHtmlEquationTable(equations):
             noBorder = 'style="border:none"'
             htmlTable = '<table style="border:none; width: 100%; ">'
-            line = '<tr {nb}> <td {nb}>${lhs}$</td> <td {nb}>$=$</td> ' \
-                   '<td style="border:none; width: 100%;">${rhs}$</td> </tr>'
+            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
             for eq in equations:
-                formatDict = {'lhs': sp.latex(eq.lhs),
-                              'rhs': sp.latex(eq.rhs),
+                formatDict = {'eq': sp.latex(eq),
                               'nb': noBorder, }
                 htmlTable += line.format(**formatDict)
             htmlTable += "</table>"
@@ -97,15 +113,24 @@ class EquationCollection:
 
         result = ""
         if len(self.subexpressions) > 0:
-            result += "<div>Subexpressions:<div>"
+            result += "<div>Subexpressions:</div>"
             result += makeHtmlEquationTable(self.subexpressions)
-        result += "<div>Main Equations:<div>"
+        result += "<div>Main Equations:</div>"
         result += makeHtmlEquationTable(self.mainEquations)
         return result
 
     def __repr__(self):
         return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])
 
+    def __str__(self):
+        result = "Subexpressions\n"
+        for eq in self.subexpressions:
+            result += str(eq) + "\n"
+        result += "Main Equations\n"
+        for eq in self.mainEquations:
+            result += str(eq) + "\n"
+        return result
+
     # -------------------------------------   Manipulation  ------------------------------------------------------------
 
     def merge(self, other):
@@ -194,11 +219,7 @@ class EquationCollection:
         :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
         """
         eqs = self.createNewWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
-        print('abc')
-        for eq in eqs:
-            print(eq)
-            sp.lambdify(eq.rhs, symbols, module)
-        lambdas = {eq.lhs: sp.lambdify(eq.rhs, symbols, module) for eq in eqs}
+        lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
 
         def f(*args, **kwargs):
             return {s: f(*args, **kwargs) for s, f in lambdas.items()}
diff --git a/equationcollection/simplificationstrategy.py b/equationcollection/simplificationstrategy.py
new file mode 100644
index 000000000..0caa0f269
--- /dev/null
+++ b/equationcollection/simplificationstrategy.py
@@ -0,0 +1,134 @@
+import sympy as sp
+import textwrap
+from collections import namedtuple
+
+
+class SimplificationStrategy:
+    """
+    A simplification strategy is an ordered collection of simplification rules.
+    Each simplification is a function taking an equation collection, and returning a new simplified
+    equation collection. The strategy can nicely print intermediate simplification stages and results
+    to Jupyter notebooks.
+    """
+
+    def __init__(self):
+        self._rules = []
+
+    def addSimplificationRule(self, rule):
+        """
+        Adds the given simplification rule to the end of the collection.
+        :param rule: function that taking one equation collection and returning a (simplified) equation collection
+        """
+        self._rules.append(rule)
+
+    @property
+    def rules(self):
+        return self._rules
+
+    def apply(self, updateRule):
+        """Applies all simplification rules to the given equation collection"""
+        for t in self._rules:
+            updateRule = t(updateRule)
+        return updateRule
+
+    def __call__(self, equationCollection):
+        """Same as apply"""
+        return self.apply(equationCollection)
+
+    def createSimplificationReport(self, equationCollection):
+        """
+        Returns a simplification report containing the number of operations at each simplification stage, together
+        with the run-time the simplification took.
+        """
+
+        ReportElement = namedtuple('ReportElement', ['simplificationName', 'adds', 'muls', 'divs', 'runtime'])
+
+        class Report:
+            def __init__(self):
+                self.elements = []
+
+            def add(self, element):
+                self.elements.append(element)
+
+            def __str__(self):
+                try:
+                    import tabulate
+                    return tabulate(self.elements, headers=['Name', 'Adds', 'Muls', 'Divs', 'Runtime'])
+                except ImportError:
+                    result = "Name, Adds, Muls, Divs, Runtime\n"
+                    for e in self.elements:
+                        result += ",".join(e) + "\n"
+                    return result
+
+            def _repr_html_(self):
+                htmlTable = '<table style="border:none">'
+                htmlTable += "<tr> <th>Name</th> <th>Adds</th> <th>Muls</th> <th>Divs</th> <th>Runtime</th></tr>"
+                line = "<tr><td>{simplificationName}</td>" \
+                       "<td>{adds}</td> <td>{muls}</td> <td>{divs}</td>  <td>{runtime}</td> </tr>"
+
+                for e in self.elements:
+                    htmlTable += line.format(**e._asdict())
+                htmlTable += "</table>"
+                return htmlTable
+
+        import time
+        report = Report()
+        op = equationCollection.operationCount
+        report.add(ReportElement("OriginalTerm", op['adds'], op['muls'], op['divs'], '-'))
+        for t in self._rules:
+            startTime = time.perf_counter()
+            equationCollection = t(equationCollection)
+            endTime = time.perf_counter()
+            op = equationCollection.operationCount
+            timeStr = "%.2f ms" % ((endTime - startTime) * 1000,)
+            report.add(ReportElement(t.__name__, op['adds'], op['muls'], op['divs'], timeStr))
+        return report
+
+    def showIntermediateResults(self, equationCollection, symbols=None):
+
+        class IntermediateResults:
+            def __init__(self, strategy, eqColl, resSyms):
+                self.strategy = strategy
+                self.equationCollection = eqColl
+                self.restrictSymbols = resSyms
+
+            def __str__(self):
+                def printEqCollection(title, eqColl):
+                    text = title
+                    if self.restrictSymbols:
+                        text += "\n".join([str(e) for e in eqColl.get(self.restrictSymbols)])
+                    else:
+                        text += textwrap.indent(str(eqColl), " " * 3)
+                    return text
+
+                result = printEqCollection("Initial Version", self.equationCollection)
+                eqColl = self.equationCollection
+                for rule in self.strategy.rules:
+                    eqColl = rule(eqColl)
+                    result += printEqCollection(rule.__name__, eqColl)
+                return result
+
+            def _repr_html_(self):
+                def printEqCollection(title, eqColl):
+                    text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, )
+                    if self.restrictSymbols:
+                        text += "\n".join(["$$" + sp.latex(e) + '$$' for e in eqColl.get(self.restrictSymbols)])
+                    else:
+                        text += eqColl._repr_html_()
+                    text += "</div>"
+                    return text
+
+                result = printEqCollection("Initial Version", self.equationCollection)
+                eqColl = self.equationCollection
+                for rule in self.strategy.rules:
+                    eqColl = rule(eqColl)
+                    result += printEqCollection(rule.__name__, eqColl)
+                return result
+
+        return IntermediateResults(self, equationCollection, symbols)
+
+    def __repr__(self):
+        result = "Simplification Strategy:\n"
+        for t in self._rules:
+            result += " - %s\n" % (t.__name__,)
+        return result
-- 
GitLab