Skip to content
Snippets Groups Projects
Commit cfef0f9d authored by Martin Bauer's avatar Martin Bauer
Browse files

finished equation collection submodule of pystencils

- moved simplification strategies from lbmpy to pystencils
- better notebook display for simplification reports
- additional demo and documentation
parent adf98ef6
No related merge requests found
from pystencils.equationcollection.equationcollection import EquationCollection from pystencils.equationcollection.equationcollection import EquationCollection
from pystencils.equationcollection.simplificationstrategy import SimplificationStrategy
import sympy as sp import sympy as sp
from pystencils.transformations import fastSubs from pystencils.sympyextensions import fastSubs, countNumberOfOperations
class EquationCollection: class EquationCollection:
...@@ -79,17 +79,33 @@ class EquationCollection: ...@@ -79,17 +79,33 @@ class EquationCollection:
"""All symbols that occur as left-hand-sides of the main equations""" """All symbols that occur as left-hand-sides of the main equations"""
return set([eq.lhs for eq in self.mainEquations]) return set([eq.lhs for eq in self.mainEquations])
@property
def operationCount(self):
"""See :func:`countNumberOfOperations` """
return countNumberOfOperations(self.allEquations)
def get(self, symbols, fromMainEquationsOnly=False):
"""Return the equations which have symbols as left hand sides"""
if not hasattr(symbols, "__len__"):
symbols = list(symbols)
symbols = set(symbols)
if not fromMainEquationsOnly:
eqsToSearchIn = self.allEquations
else:
eqsToSearchIn = self.mainEquations
return [eq for eq in eqsToSearchIn if eq.lhs in symbols]
# ----------------------------------------- Display and Printing ------------------------------------------------- # ----------------------------------------- Display and Printing -------------------------------------------------
def _repr_html_(self): def _repr_html_(self):
def makeHtmlEquationTable(equations): def makeHtmlEquationTable(equations):
noBorder = 'style="border:none"' noBorder = 'style="border:none"'
htmlTable = '<table style="border:none; width: 100%; ">' htmlTable = '<table style="border:none; width: 100%; ">'
line = '<tr {nb}> <td {nb}>${lhs}$</td> <td {nb}>$=$</td> ' \ line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> '
'<td style="border:none; width: 100%;">${rhs}$</td> </tr>'
for eq in equations: for eq in equations:
formatDict = {'lhs': sp.latex(eq.lhs), formatDict = {'eq': sp.latex(eq),
'rhs': sp.latex(eq.rhs),
'nb': noBorder, } 'nb': noBorder, }
htmlTable += line.format(**formatDict) htmlTable += line.format(**formatDict)
htmlTable += "</table>" htmlTable += "</table>"
...@@ -97,15 +113,24 @@ class EquationCollection: ...@@ -97,15 +113,24 @@ class EquationCollection:
result = "" result = ""
if len(self.subexpressions) > 0: if len(self.subexpressions) > 0:
result += "<div>Subexpressions:<div>" result += "<div>Subexpressions:</div>"
result += makeHtmlEquationTable(self.subexpressions) result += makeHtmlEquationTable(self.subexpressions)
result += "<div>Main Equations:<div>" result += "<div>Main Equations:</div>"
result += makeHtmlEquationTable(self.mainEquations) result += makeHtmlEquationTable(self.mainEquations)
return result return result
def __repr__(self): def __repr__(self):
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations]) return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])
def __str__(self):
result = "Subexpressions\n"
for eq in self.subexpressions:
result += str(eq) + "\n"
result += "Main Equations\n"
for eq in self.mainEquations:
result += str(eq) + "\n"
return result
# ------------------------------------- Manipulation ------------------------------------------------------------ # ------------------------------------- Manipulation ------------------------------------------------------------
def merge(self, other): def merge(self, other):
...@@ -194,11 +219,7 @@ class EquationCollection: ...@@ -194,11 +219,7 @@ class EquationCollection:
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
""" """
eqs = self.createNewWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations eqs = self.createNewWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
print('abc') lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
for eq in eqs:
print(eq)
sp.lambdify(eq.rhs, symbols, module)
lambdas = {eq.lhs: sp.lambdify(eq.rhs, symbols, module) for eq in eqs}
def f(*args, **kwargs): def f(*args, **kwargs):
return {s: f(*args, **kwargs) for s, f in lambdas.items()} return {s: f(*args, **kwargs) for s, f in lambdas.items()}
......
import sympy as sp
import textwrap
from collections import namedtuple
class SimplificationStrategy:
"""
A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified
equation collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks.
"""
def __init__(self):
self._rules = []
def addSimplificationRule(self, rule):
"""
Adds the given simplification rule to the end of the collection.
:param rule: function that taking one equation collection and returning a (simplified) equation collection
"""
self._rules.append(rule)
@property
def rules(self):
return self._rules
def apply(self, updateRule):
"""Applies all simplification rules to the given equation collection"""
for t in self._rules:
updateRule = t(updateRule)
return updateRule
def __call__(self, equationCollection):
"""Same as apply"""
return self.apply(equationCollection)
def createSimplificationReport(self, equationCollection):
"""
Returns a simplification report containing the number of operations at each simplification stage, together
with the run-time the simplification took.
"""
ReportElement = namedtuple('ReportElement', ['simplificationName', 'adds', 'muls', 'divs', 'runtime'])
class Report:
def __init__(self):
self.elements = []
def add(self, element):
self.elements.append(element)
def __str__(self):
try:
import tabulate
return tabulate(self.elements, headers=['Name', 'Adds', 'Muls', 'Divs', 'Runtime'])
except ImportError:
result = "Name, Adds, Muls, Divs, Runtime\n"
for e in self.elements:
result += ",".join(e) + "\n"
return result
def _repr_html_(self):
htmlTable = '<table style="border:none">'
htmlTable += "<tr> <th>Name</th> <th>Adds</th> <th>Muls</th> <th>Divs</th> <th>Runtime</th></tr>"
line = "<tr><td>{simplificationName}</td>" \
"<td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{runtime}</td> </tr>"
for e in self.elements:
htmlTable += line.format(**e._asdict())
htmlTable += "</table>"
return htmlTable
import time
report = Report()
op = equationCollection.operationCount
report.add(ReportElement("OriginalTerm", op['adds'], op['muls'], op['divs'], '-'))
for t in self._rules:
startTime = time.perf_counter()
equationCollection = t(equationCollection)
endTime = time.perf_counter()
op = equationCollection.operationCount
timeStr = "%.2f ms" % ((endTime - startTime) * 1000,)
report.add(ReportElement(t.__name__, op['adds'], op['muls'], op['divs'], timeStr))
return report
def showIntermediateResults(self, equationCollection, symbols=None):
class IntermediateResults:
def __init__(self, strategy, eqColl, resSyms):
self.strategy = strategy
self.equationCollection = eqColl
self.restrictSymbols = resSyms
def __str__(self):
def printEqCollection(title, eqColl):
text = title
if self.restrictSymbols:
text += "\n".join([str(e) for e in eqColl.get(self.restrictSymbols)])
else:
text += textwrap.indent(str(eqColl), " " * 3)
return text
result = printEqCollection("Initial Version", self.equationCollection)
eqColl = self.equationCollection
for rule in self.strategy.rules:
eqColl = rule(eqColl)
result += printEqCollection(rule.__name__, eqColl)
return result
def _repr_html_(self):
def printEqCollection(title, eqColl):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, )
if self.restrictSymbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' for e in eqColl.get(self.restrictSymbols)])
else:
text += eqColl._repr_html_()
text += "</div>"
return text
result = printEqCollection("Initial Version", self.equationCollection)
eqColl = self.equationCollection
for rule in self.strategy.rules:
eqColl = rule(eqColl)
result += printEqCollection(rule.__name__, eqColl)
return result
return IntermediateResults(self, equationCollection, symbols)
def __repr__(self):
result = "Simplification Strategy:\n"
for t in self._rules:
result += " - %s\n" % (t.__name__,)
return result
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment