Commit ef924b18 authored by Martin Bauer's avatar Martin Bauer
Browse files

Code Cleanup

- assignment collection
- sympyextensions
parent c43672d2
...@@ -5,6 +5,7 @@ from pystencils.kernelcreation import createKernel, createIndexedKernel ...@@ -5,6 +5,7 @@ from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot from pystencils.display_utils import showCode, toDot
from pystencils.assignment_collection import AssignmentCollection from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
__all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol', 'TypedSymbol',
...@@ -12,4 +13,5 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', ...@@ -12,4 +13,5 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'createKernel', 'createIndexedKernel', 'createKernel', 'createIndexedKernel',
'showCode', 'toDot', 'showCode', 'toDot',
'AssignmentCollection', 'AssignmentCollection',
'Assignment'] 'Assignment',
'SymbolCreator']
import sympy as sp import sympy as sp
from typing import Callable, List
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
from pystencils.sympyextensions import replaceAdditive from pystencils.sympyextensions import subs_additive
def sympyCseOnEquationList(eqs): def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
ec = AssignmentCollection(eqs, []) """Extracts common subexpressions from a list of assignments."""
return sympyCSE(ec).allEquations ec = AssignmentCollection(assignments, [])
return sympy_cse(ec).all_assignments
def sympyCSE(assignment_collection): def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
""" """Searches for common subexpressions inside the equation collection.
Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well
as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection Searches is done in both the existing subexpressions as well as the assignments themselves.
It uses the sympy subexpression detection to do this. Return a new equation collection
with the additional subexpressions found with the additional subexpressions found
""" """
symbolGen = assignment_collection.subexpressionSymbolNameGenerator symbol_gen = ac.subexpression_symbol_generator
replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments, replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
symbols=symbolGen) symbols=symbol_gen)
replacementEqs = [Assignment(*r) for r in replacements] replacement_eqs = [Assignment(*r) for r in replacements]
modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)] modified_subexpressions = new_eq[:len(ac.subexpressions)]
modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):] modified_update_equations = new_eq[len(ac.subexpressions):]
newSubexpressions = replacementEqs + modifiedSubexpressions new_subexpressions = replacement_eqs + modified_subexpressions
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs] new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs]
return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions) return ac.copy(modified_update_equations, new_subexpressions)
def applyOnAllEquations(assignment_collection, operation): def apply_to_all_assignments(assignment_collection: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection""" """Applies sympy expand operation to all equations in collection"""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments] result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result) return assignment_collection.copy(result)
def applyOnAllSubexpressions(assignment_collection, operation): def apply_on_all_subexpressions(ac: AssignmentCollection,
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions] operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
return assignment_collection.copy(assignment_collection.mainAssignments, result) result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
def subexpressionSubstitutionInExistingSubexpressions(assignment_collection): def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions""" """Goes through the subexpressions list and replaces the term in the following subexpressions"""
result = [] result = []
for outerCtr, s in enumerate(assignment_collection.subexpressions): for outerCtr, s in enumerate(ac.subexpressions):
newRhs = s.rhs new_rhs = s.rhs
for innerCtr in range(outerCtr): for innerCtr in range(outerCtr):
subExpr = assignment_collection.subexpressions[innerCtr] sub_expr = ac.subexpressions[innerCtr]
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
result.append(Assignment(s.lhs, newRhs)) result.append(Assignment(s.lhs, new_rhs))
return assignment_collection.copy(assignment_collection.mainAssignments, result) return ac.copy(ac.main_assignments, result)
def subexpressionSubstitutionInmainAssignments(assignment_collection): def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection:
"""Replaces already existing subexpressions in the equations of the assignment_collection""" """Replaces already existing subexpressions in the equations of the assignment_collection."""
result = [] result = []
for s in assignment_collection.mainAssignments: for s in ac.main_assignments:
newRhs = s.rhs new_rhs = s.rhs
for subExpr in assignment_collection.subexpressions: for subExpr in ac.subexpressions:
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) new_rhs = subs_additive(new_rhs, subExpr.lhs, subExpr.rhs, required_match_replacement=1.0)
result.append(Assignment(s.lhs, newRhs)) result.append(Assignment(s.lhs, new_rhs))
return assignment_collection.copy(result) return ac.copy(result)
def addSubexpressionsForDivisions(assignment_collection): def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection:
"""Introduces subexpressions for all divisions which have no constant in the denominator. """Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.
"""
divisors = set() divisors = set()
def searchDivisors(term): def search_divisors(term):
if term.func == sp.Pow: if term.func == sp.Pow:
if term.exp.is_integer and term.exp.is_number and term.exp < 0: if term.exp.is_integer and term.exp.is_number and term.exp < 0:
divisors.add(term) divisors.add(term)
else: else:
for a in term.args: for a in term.args:
searchDivisors(a) search_divisors(a)
for eq in assignment_collection.allEquations: for eq in ac.all_assignments:
searchDivisors(eq.rhs) search_divisors(eq.rhs)
newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} substitutions = {divisor: newSymbol for newSymbol, divisor in zip(new_symbol_gen, divisors)}
return assignment_collection.copyWithSubstitutionsApplied(substitutions, True) return ac.new_with_substitutions(substitutions, True)
import sympy as sp import sympy as sp
from collections import namedtuple from collections import namedtuple
from typing import Callable, Any, Optional, Sequence
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
class SimplificationStrategy(object): class SimplificationStrategy(object):
""" """A simplification strategy is an ordered collection of simplification rules.
A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified Each simplification is a function taking an equation collection, and returning a new simplified
equation collection. The strategy can nicely print intermediate simplification stages and results equation collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks. to Jupyter notebooks.
...@@ -13,10 +15,11 @@ class SimplificationStrategy(object): ...@@ -13,10 +15,11 @@ class SimplificationStrategy(object):
def __init__(self): def __init__(self):
self._rules = [] self._rules = []
def add(self, rule): def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None:
""" """Adds the given simplification rule to the end of the collection.
Adds the given simplification rule to the end of the collection.
:param rule: function that taking one equation collection and returning a (simplified) equation collection Args:
rule: function that rewrites/simplifies an assignment collection
""" """
self._rules.append(rule) self._rules.append(rule)
...@@ -24,19 +27,20 @@ class SimplificationStrategy(object): ...@@ -24,19 +27,20 @@ class SimplificationStrategy(object):
def rules(self): def rules(self):
return self._rules return self._rules
def apply(self, updateRule): def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
"""Applies all simplification rules to the given equation collection""" """Runs all rules on the given assignment collection."""
for t in self._rules: for t in self._rules:
updateRule = t(updateRule) assignment_collection = t(assignment_collection)
return updateRule return assignment_collection
def __call__(self, assignment_collection): def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
"""Same as apply""" """Same as apply"""
return self.apply(assignment_collection) return self.apply(assignment_collection)
def createSimplificationReport(self, assignment_collection): def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any:
""" """Creates a report to be displayed as HTML in a Jupyter notebook.
Returns a simplification report containing the number of operations at each simplification stage, together
The simplification report contains the number of operations at each simplification stage together
with the run-time the simplification took. with the run-time the simplification took.
""" """
...@@ -60,70 +64,83 @@ class SimplificationStrategy(object): ...@@ -60,70 +64,83 @@ class SimplificationStrategy(object):
return result return result
def _repr_html_(self): def _repr_html_(self):
htmlTable = '<table style="border:none">' html_table = '<table style="border:none">'
htmlTable += "<tr><th>Name</th><th>Runtime</th><th>Adds</th><th>Muls</th><th>Divs</th><th>Total</th></tr>" html_table += "<tr><th>Name</th>" \
"<th>Runtime</th>" \
"<th>Adds</th>" \
"<th>Muls</th>" \
"<th>Divs</th>" \
"<th>Total</th></tr>"
line = "<tr><td>{simplificationName}</td>" \ line = "<tr><td>{simplificationName}</td>" \
"<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>" "<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>"
for e in self.elements: for e in self.elements:
htmlTable += line.format(**e._asdict()) # noinspection PyProtectedMember
htmlTable += "</table>" html_table += line.format(**e._asdict())
return htmlTable html_table += "</table>"
return html_table
import timeit import timeit
report = Report() report = Report()
op = assignment_collection.operationCount op = assignment_collection.operation_count
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total))
for t in self._rules: for t in self._rules:
startTime = timeit.default_timer() start_time = timeit.default_timer()
assignment_collection = t(assignment_collection) assignment_collection = t(assignment_collection)
endTime = timeit.default_timer() end_time = timeit.default_timer()
op = assignment_collection.operationCount op = assignment_collection.operation_count
timeStr = "%.2f ms" % ((endTime - startTime) * 1000,) time_str = "%.2f ms" % ((end_time - start_time) * 1000,)
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report return report
def showIntermediateResults(self, assignment_collection, symbols=None): def show_intermediate_results(self, assignment_collection: AssignmentCollection,
symbols: Optional[Sequence[sp.Symbol]] = None) -> Any:
"""Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.
Args:
assignment_collection: the collection to apply the rules to
symbols: if not None, only the assignments are shown that have one of these symbols as left hand side
"""
class IntermediateResults: class IntermediateResults:
def __init__(self, strategy, eqColl, resSyms): def __init__(self, strategy, collection, restrict_symbols):
self.strategy = strategy self.strategy = strategy
self.assignment_collection = eqColl self.assignment_collection = collection
self.restrictSymbols = resSyms self.restrict_symbols = restrict_symbols
def __str__(self): def __str__(self):
def printEqCollection(title, eqColl): def print_assignment_collection(title, c):
text = title text = title
if self.restrictSymbols: if self.restrict_symbols:
text += "\n".join([str(e) for e in eqColl.get(self.restrictSymbols)]) text += "\n".join([str(e) for e in c.get(self.restrict_symbols)])
else: else:
text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True))) text += (" " * 3 + (" " * 3).join(str(c).splitlines(True)))
return text return text
result = printEqCollection("Initial Version", self.assignment_collection) result = print_assignment_collection("Initial Version", self.assignment_collection)
eqColl = self.assignment_collection collection = self.assignment_collection
for rule in self.strategy.rules: for rule in self.strategy.rules:
eqColl = rule(eqColl) collection = rule(collection)
result += printEqCollection(rule.__name__, eqColl) result += print_assignment_collection(rule.__name__, collection)
return result return result
def _repr_html_(self): def _repr_html_(self):
def printEqCollection(title, eqColl): def print_assignment_collection(title, c):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, ) text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, )
if self.restrictSymbols: if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' for e in eqColl.get(self.restrictSymbols)]) text += "\n".join(["$$" + sp.latex(e) + '$$' for e in c.get(self.restrict_symbols)])
else: else:
text += eqColl._repr_html_() # noinspection PyProtectedMember
text += c._repr_html_()
text += "</div>" text += "</div>"
return text return text
result = printEqCollection("Initial Version", self.assignment_collection) result = print_assignment_collection("Initial Version", self.assignment_collection)
eqColl = self.assignment_collection collection = self.assignment_collection
for rule in self.strategy.rules: for rule in self.strategy.rules:
eqColl = rule(eqColl) collection = rule(collection)
result += printEqCollection(rule.__name__, eqColl) result += print_assignment_collection(rule.__name__, collection)
return result return result
return IntermediateResults(self, assignment_collection, symbols) return IntermediateResults(self, assignment_collection, symbols)
......
...@@ -2,7 +2,7 @@ import sympy as sp ...@@ -2,7 +2,7 @@ import sympy as sp
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
from pystencils.field import Field from pystencils.field import Field
from pystencils.data_types import TypedSymbol, createType, castFunc from pystencils.data_types import TypedSymbol, createType, castFunc
from pystencils.sympyextensions import fastSubs from pystencils.sympyextensions import fast_subs
class Node(object): class Node(object):
...@@ -275,11 +275,11 @@ class Block(Node): ...@@ -275,11 +275,11 @@ class Block(Node):
@property @property
def undefinedSymbols(self): def undefinedSymbols(self):
result = set() result = set()
definedSymbols = set() defined_symbols = set()
for a in self.args: for a in self.args:
result.update(a.undefinedSymbols) result.update(a.undefinedSymbols)
definedSymbols.update(a.symbolsDefined) defined_symbols.update(a.symbolsDefined)
return result - definedSymbols return result - defined_symbols
def __str__(self): def __str__(self):
return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes) return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)
...@@ -426,8 +426,8 @@ class SympyAssignment(Node): ...@@ -426,8 +426,8 @@ class SympyAssignment(Node):
self._isDeclaration = False self._isDeclaration = False
def subs(self, *args, **kwargs): def subs(self, *args, **kwargs):
self.lhs = fastSubs(self.lhs, *args, **kwargs) self.lhs = fast_subs(self.lhs, *args, **kwargs)
self.rhs = fastSubs(self.rhs, *args, **kwargs) self.rhs = fast_subs(self.rhs, *args, **kwargs)
@property @property
def args(self): def args(self):
...@@ -494,11 +494,11 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -494,11 +494,11 @@ class ResolvedFieldAccess(sp.Indexed):
self.args[1].subs(old, new), self.args[1].subs(old, new),
self.field, self.offsets, self.idxCoordinateValues) self.field, self.offsets, self.idxCoordinateValues)
def fastSubs(self, subsDict): def fast_subs(self, substitutions):
if self in subsDict: if self in substitutions:
return subsDict[self] return substitutions[self]
return ResolvedFieldAccess(self.args[0].subs(subsDict), return ResolvedFieldAccess(self.args[0].subs(substitutions),
self.args[1].subs(subsDict), self.args[1].subs(substitutions),
self.field, self.offsets, self.idxCoordinateValues) self.field, self.offsets, self.idxCoordinateValues)
def _hashable_content(self): def _hashable_content(self):
......
import sympy as sp import sympy as sp
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from pystencils.sympyextensions import normalizeProduct, prod from pystencils.sympyextensions import normalize_product, prod
def defaultDiffSortKey(d): def defaultDiffSortKey(d):
...@@ -57,7 +57,7 @@ class Diff(sp.Expr): ...@@ -57,7 +57,7 @@ class Diff(sp.Expr):
if self.arg.func != sp.Mul: if self.arg.func != sp.Mul:
constant, variable = 1, self.arg constant, variable = 1, self.arg
else: else:
for factor in normalizeProduct(self.arg): for factor in normalize_product(self.arg):
if factor in functions or isinstance(factor, Diff): if factor in functions or isinstance(factor, Diff):
variable *= factor variable *= factor
else: else:
...@@ -150,7 +150,7 @@ class DiffOperator(sp.Expr): ...@@ -150,7 +150,7 @@ class DiffOperator(sp.Expr):
i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t) i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t)
""" """
def handleMul(mul): def handleMul(mul):
args = normalizeProduct(mul) args = normalize_product(mul)
diffs = [a for a in args if isinstance(a, DiffOperator)] diffs = [a for a in args if isinstance(a, DiffOperator)]
if len(diffs) == 0: if len(diffs) == 0:
return mul * argument if applyToConstants else mul return mul * argument if applyToConstants else mul
...@@ -254,7 +254,7 @@ def fullDiffExpand(expr, functions=None, constants=None): ...@@ -254,7 +254,7 @@ def fullDiffExpand(expr, functions=None, constants=None):
for term in diffInner.args if diffInner.func == sp.Add else [diffInner]: for term in diffInner.args if diffInner.func == sp.Add else [diffInner]:
independentTerms = 1 independentTerms = 1
dependentTerms = [] dependentTerms = []
for factor in normalizeProduct(term): for factor in normalize_product(term):
if factor in functions or isinstance(factor, Diff): if factor in functions or isinstance(factor, Diff):
dependentTerms.append(factor) dependentTerms.append(factor)
else: else:
...@@ -310,7 +310,7 @@ def expandUsingProductRule(expr): ...@@ -310,7 +310,7 @@ def expandUsingProductRule(expr):
if arg.func not in (sp.Mul, sp.Pow): if arg.func not in (sp.Mul, sp.Pow):
return Diff(arg, target=expr.target, superscript=expr.superscript) return Diff(arg, target=expr.target, superscript=expr.superscript)
else: else:
prodList = normalizeProduct(arg) prodList = normalize_product(arg)
result = 0 result = 0
for i in range(len(prodList)): for i in range(len(prodList)):
preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j) preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j)
...@@ -347,7 +347,7 @@ def combineUsingProductRule(expr): ...@@ -347,7 +347,7 @@ def combineUsingProductRule(expr):
if isinstance(term, Diff): if isinstance(term, Diff):
diffDict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg)) diffDict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg))
else: else:
mulArgs = normalizeProduct(term) mulArgs = normalize_product(term)
diffs = [d for d in mulArgs if isinstance(d, Diff)] diffs = [d for d in mulArgs if isinstance(d, Diff)]
factor = prod(d for d in mulArgs if not isinstance(d, Diff)) factor = prod(d for d in mulArgs if not isinstance(d, Diff))
if len(diffs) == 0: if len(diffs) == 0:
......
...@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase ...@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.alignedarray import aligned_empty from pystencils.alignedarray import aligned_empty
from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType
from pystencils.sympyextensions import isIntegerSequence from pystencils.sympyextensions import is_integer_sequence
class FieldType(Enum): class FieldType(Enum):
...@@ -221,7 +221,7 @@ class Field(object): ...@@ -221,7 +221,7 @@ class Field(object):
@property @property
def hasFixedShape(self): def hasFixedShape(self):
return isIntegerSequence(self.shape) return is_integer_sequence(self.shape)
@property @property