Commit ef924b18 authored by Martin Bauer
Browse files

Code Cleanup

- assignment collection
- sympyextensions
parent c43672d2
......@@ -5,6 +5,7 @@ from pystencils.kernelcreation import createKernel, createIndexedKernel
from pystencils.display_utils import showCode, toDot
from pystencils.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
__all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'TypedSymbol',
......@@ -12,4 +13,5 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions',
'createKernel', 'createIndexedKernel',
'showCode', 'toDot',
'AssignmentCollection',
'Assignment']
'Assignment',
'SymbolCreator']
import sympy as sp
from copy import copy
from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable
from pystencils.assignment import Assignment
from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically
from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically
class AssignmentCollection(object):
class AssignmentCollection:
"""
A collection of equations with subexpression definitions, also represented as equations,
A collection of equations with subexpression definitions, also represented as assignments,
that are used in the main equations. AssignmentCollection can be passed to simplification methods.
These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create
equation collections to transport information to the simplification system.
:ivar mainAssignments: list of sympy equations
:ivar subexpressions: list of sympy equations defining subexpressions used in main equations
:ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning.
Attributes:
main_assignments: list of assignments
subexpressions: list of assignments defining subexpressions used in main equations
simplification_hints: dict that is used to annotate the equation collection with hints that are
used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning.
subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection
"""
# ----------------------------------------- Creation ---------------------------------------------------------------
# ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
self.mainAssignments = equations
self.subexpressions = subExpressions
def __init__(self, main_assignments: List[Assignment], subexpressions: List[Assignment],
simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
self.main_assignments = main_assignments
self.subexpressions = subexpressions
if simplificationHints is None:
simplificationHints = {}
if simplification_hints is None:
simplification_hints = {}
self.simplificationHints = simplificationHints
self.simplification_hints = simplification_hints
if subexpressionSymbolNameGenerator is None:
self.subexpressionSymbolNameGenerator = SymbolGen()
if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen()
else:
self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator
self.subexpression_symbol_generator = subexpression_symbol_generator
@property
def mainTerms(self):
return []
def add_simplification_hint(self, key: str, value: Any) -> None:
    """Stores a new entry in the simplification_hints dictionary.

    Args:
        key: name of the hint, must not be contained in ``simplification_hints`` yet
        value: value that is stored for this hint

    An already existing key trips the assertion below instead of being overwritten.
    """
    hints = self.simplification_hints
    assert key not in hints, "This hint already exists"
    hints[key] = value
def copy(self, mainAssignments=None, subexpressions=None):
res = copy(self)
res.simplificationHints = self.simplificationHints.copy()
res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator)
def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
"""Adds a subexpression to current collection.
if mainAssignments is not None:
res.mainAssignments = mainAssignments
else:
res.mainAssignments = self.mainAssignments.copy()
Args:
rhs: right hand side of new subexpression
lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
topological_sort: sort the subexpressions topologically after insertion, to make sure that
definition of a symbol comes before its usage. If False, subexpression is appended.
if subexpressions is not None:
res.subexpressions = subexpressions
else:
res.subexpressions = self.subexpressions.copy()
return res
def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False,
substituteOnLhs=True):
"""
Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
Substitutions are made in the subexpression terms and the main equations
Returns:
left hand side symbol (which could have been generated)
"""
if substituteOnLhs:
newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments]
else:
newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments]
if addSubstitutionsAsSubexpressions:
newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions
newSubexpressions = sortEquationsTopologically(newSubexpressions)
return self.copy(newEquations, newSubexpressions)
if lhs is None:
lhs = sp.Dummy()
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topological_sort:
self.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
return lhs
def addSimplificationHint(self, key, value):
"""
Adds an entry to the simplificationHints dictionary, and checks that is does not exist yet
"""
assert key not in self.simplificationHints, "This hint already exists"
self.simplificationHints[key] = value
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
    """Re-orders the stored assignments in place such that a symbol is defined before it is used.

    Args:
        sort_subexpressions: if True, ``subexpressions`` is replaced by its topologically sorted version
        sort_main_assignments: if True, ``main_assignments`` is replaced by its topologically sorted version
    """
    # subexpressions are handled first, then the main assignments
    requested = ((sort_subexpressions, 'subexpressions'),
                 (sort_main_assignments, 'main_assignments'))
    for enabled, attribute_name in requested:
        if enabled:
            unsorted = getattr(self, attribute_name)
            setattr(self, attribute_name, sort_assignments_topologically(unsorted))
# ---------------------------------------------- Properties -------------------------------------------------------
@property
def allEquations(self):
"""Subexpression and main equations in one sequence"""
return self.subexpressions + self.mainAssignments
def all_assignments(self) -> List[Assignment]:
    """Subexpressions followed by the main assignments, concatenated into a new sequence."""
    leading, trailing = self.subexpressions, self.main_assignments
    return leading + trailing
@property
def freeSymbols(self):
"""All symbols used in the equation collection, which have not been defined inside the equation system"""
freeSymbols = set()
for eq in self.allEquations:
freeSymbols.update(eq.rhs.atoms(sp.Symbol))
return freeSymbols - self.boundSymbols
def free_symbols(self) -> Set[sp.Symbol]:
    """Symbols used on some right hand side that are never assigned inside this collection."""
    # gather the symbols of every right hand side, then remove everything that has a definition
    symbols_on_rhs = set().union(*(assignment.rhs.atoms(sp.Symbol) for assignment in self.all_assignments))
    return symbols_on_rhs - self.bound_symbols
@property
def boundSymbols(self):
"""Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
boundSymbolsSet = set([eq.lhs for eq in self.allEquations])
assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \
def bound_symbols(self) -> Set[sp.Symbol]:
    """Set of the left hand sides of all assignments (subexpressions and main assignments).

    The assertion checks that no left hand side occurs twice, i.e. that the collection is in SSA form.
    """
    left_hand_sides = {assignment.lhs for assignment in self.all_assignments}
    number_of_assignments = len(self.subexpressions) + len(self.main_assignments)
    # a duplicated left hand side collapses in the set and makes the counts differ
    assert len(left_hand_sides) == number_of_assignments, \
        "Not in SSA form - same symbol assigned multiple times"
    return left_hand_sides
@property
def definedSymbols(self):
"""All symbols that occur as left-hand-sides of the main equations"""
return set([eq.lhs for eq in self.mainAssignments])
def defined_symbols(self) -> Set[sp.Symbol]:
    """Left hand sides of the main assignments, collected in a set (subexpressions are not considered)."""
    return {main_assignment.lhs for main_assignment in self.main_assignments}
@property
def operationCount(self):
"""See :func:`countNumberOfOperations` """
return countNumberOfOperations(self.allEquations, onlyType=None)
def operation_count(self):
    """Result of :func:`count_operations` applied to all assignments, without restricting the type."""
    every_assignment = self.all_assignments
    return count_operations(every_assignment, only_type=None)
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
    """Returns the passed symbols together with all symbols they (transitively) depend on.

    A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- someExpression(b)' i.e. when
    'b' is required to compute 'a'. Starting from the passed symbols, the right hand sides of their
    defining assignments (subexpressions and main assignments) are followed until no new symbol shows up.

    Args:
        symbols: symbols to start the search from; they are always part of the result

    Returns:
        set containing the start symbols and every symbol reachable through right hand sides
    """
    from collections import deque  # local import: O(1) pops from the left, list.pop(0) is O(n)

    pending = deque(symbols)
    definitions = {assignment.lhs: assignment.rhs for assignment in self.all_assignments}
    handled_symbols = set()
    while pending:
        symbol = pending.popleft()
        if symbol in handled_symbols:
            continue
        if symbol in definitions:
            # everything on the defining right hand side has to be visited as well
            pending.extend(definitions[symbol].atoms(sp.Symbol))
        handled_symbols.add(symbol)
    return handled_symbols
def get(self, symbols: Sequence[sp.Symbol], from_main_assignments_only=False) -> List[Assignment]:
    """Extracts all assignments that have a left hand side that is contained in the symbols parameter.

    Args:
        symbols: return assignments that have one of these symbols as left hand side
        from_main_assignments_only: search only in main assignments (exclude subexpressions)

    Returns:
        matching assignments, in the order in which they are stored in the collection
    """
    # a set gives constant time membership tests and accepts any iterable, including generators
    wanted = set(symbols)
    candidates = self.main_assignments if from_main_assignments_only else self.all_assignments
    return [assignment for assignment in candidates if assignment.lhs in wanted]
# ----------------------------------------- Display and Printing -------------------------------------------------
def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None):
"""Returns a python function to evaluate this equation collection.
def _repr_html_(self):
def makeHtmlEquationTable(equations):
noBorder = 'style="border:none"'
htmlTable = '<table style="border:none; width: 100%; ">'
line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> '
for eq in equations:
formatDict = {'eq': sp.latex(eq),
'nb': noBorder, }
htmlTable += line.format(**formatDict)
htmlTable += "</table>"
return htmlTable
Args:
symbols: symbol(s) which are the parameter for the created function
fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
result = ""
if len(self.subexpressions) > 0:
result += "<div>Subexpressions:</div>"
result += makeHtmlEquationTable(self.subexpressions)
result += "<div>Main Assignments:</div>"
result += makeHtmlEquationTable(self.mainAssignments)
return result
Examples:
>>> a, b, c, d = sp.symbols("a b c d")
>>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
... subexpressions=[Assignment(b, a + b / 2)])
>>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
>>> python_function(4)
{c: 6, d: 18}
"""
assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
assignments = assignments.new_without_subexpressions().main_assignments
lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
def __repr__(self):
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments])
def f(*args, **kwargs):
return {s: func(*args, **kwargs) for s, func in lambdas.items()}
def __str__(self):
result = "Subexpressions\n"
for eq in self.subexpressions:
result += str(eq) + "\n"
result += "Main Assignments\n"
for eq in self.mainAssignments:
result += str(eq) + "\n"
return result
return f
# ---------------------------- Creating new modified collections ---------------------------------------------------
# ------------------------------------- Manipulation ------------------------------------------------------------
def copy(self,
main_assignments: Optional[List[Assignment]] = None,
subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
"""Returns a copy with optionally replaced main_assignments and/or subexpressions."""
def merge(self, other):
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
ownDefs = set([e.lhs for e in self.mainAssignments])
otherDefs = set([e.lhs for e in other.mainAssignments])
assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"
res = copy(self)
res.simplification_hints = self.simplification_hints.copy()
res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
substitutionDict = {}
if main_assignments is not None:
res.main_assignments = main_assignments
else:
res.main_assignments = self.main_assignments.copy()
processedOtherSubexpressionEquations = []
for otherSubexpressionEq in other.subexpressions:
if otherSubexpressionEq.lhs in ownSubexpressionSymbols:
if otherSubexpressionEq.rhs == ownSubexpressionSymbols[otherSubexpressionEq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
newLhs = next(self.subexpressionSymbolNameGenerator)
newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
processedOtherSubexpressionEquations.append(newEq)
substitutionDict[otherSubexpressionEq.lhs] = newLhs
else:
processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
if subexpressions is not None:
res.subexpressions = subexpressions
else:
res.subexpressions = self.subexpressions.copy()
processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments]
return self.copy(self.mainAssignments + processedOthermainAssignments,
self.subexpressions + processedOtherSubexpressionEquations)
return res
def getDependentSymbols(self, symbolSequence):
"""Returns a list of symbols that depend on the passed symbols."""
def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
substitute_on_lhs: bool = True) -> 'AssignmentCollection':
"""Returns new object, where terms are substituted according to the passed substitution dict.
queue = list(symbolSequence)
Args:
substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
def addSymbolsFromExpr(expr):
dependentSymbols = expr.atoms(sp.Symbol)
for ds in dependentSymbols:
queue.append(ds)
Returns:
New AssignmentCollection where substitutions have been applied, self is not altered.
"""
if substitute_on_lhs:
new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions]
new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments]
else:
new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions]
new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments]
handledSymbols = set()
eqMap = {e.lhs: e.rhs for e in self.allEquations}
if add_substitutions_as_subexpressions:
new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions
new_subexpressions = sort_assignments_topologically(new_subexpressions)
return self.copy(new_equations, new_subexpressions)
while len(queue) > 0:
e = queue.pop(0)
if e in handledSymbols:
continue
if e in eqMap:
addSymbolsFromExpr(eqMap[e])
handledSymbols.add(e)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols"
return handledSymbols
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {}
def extract(self, symbolsToExtract):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract = set(symbolsToExtract)
dependentSymbols = self.getDependentSymbols(symbolsToExtract)
newEquations = []
for eq in self.allEquations:
if eq.lhs in symbolsToExtract:
newEquations.append(eq)
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
return AssignmentCollection(newEquations, newSubExpr)
def newWithoutUnusedSubexpressions(self):
"""Returns a new equation collection containing only the subexpressions that
are used/referenced in the equations"""
allLhs = [eq.lhs for eq in self.mainAssignments]
return self.extract(allLhs)
def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True):
if lhs is None:
lhs = sp.Dummy()
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topologicalSort:
self.topologicalSort(subexpressions=True, mainAssignments=False)
return lhs
processed_other_subexpression_equations = []
for otherSubexpressionEq in other.subexpressions:
if otherSubexpressionEq.lhs in own_subexpression_symbols:
if otherSubexpressionEq.rhs == own_subexpression_symbols[otherSubexpressionEq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(otherSubexpressionEq.rhs, substitution_dict))
processed_other_subexpression_equations.append(new_eq)
substitution_dict[otherSubexpressionEq.lhs] = new_lhs
else:
processed_other_subexpression_equations.append(fast_subs(otherSubexpressionEq, substitution_dict))
processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
return self.copy(self.main_assignments + processed_other_main_assignments,
self.subexpressions + processed_other_subexpression_equations)
def topologicalSort(self, subexpressions=True, mainAssignments=True):
if subexpressions:
self.subexpressions = sortEquationsTopologically(self.subexpressions)
if mainAssignments:
self.mainAssignments = sortEquationsTopologically(self.mainAssignments)
def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
    """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.

    The result is constructed from the two filtered lists only; no other state of self is passed on.

    Returns:
        new AssignmentCollection, self is not altered
    """
    wanted = set(symbols_to_extract)
    required = self.dependent_symbols(wanted)

    # assignments (subexpressions or main assignments) that define one of the requested symbols
    selected = [assignment for assignment in self.all_assignments if assignment.lhs in wanted]
    # subexpressions that are needed for the selection but are not part of the selection themselves
    helpers = [subexpression for subexpression in self.subexpressions
               if subexpression.lhs in required and subexpression.lhs not in wanted]
    return AssignmentCollection(selected, helpers)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
    """Returns the result of ``new_filtered`` for the left hand sides of all main assignments."""
    main_left_hand_sides = []
    for main_assignment in self.main_assignments:
        main_left_hand_sides.append(main_assignment.lhs)
    return self.new_filtered(main_left_hand_sides)
def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
"""Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
new_subexpressions = []
subs_dict = None
for se in self.subexpressions:
if se.lhs == symbol:
subsDict = {se.lhs: se.rhs}
subs_dict = {se.lhs: se.rhs}
else:
newSubexpressions.append(se)
if subsDict is None:
new_subexpressions.append(se)
if subs_dict is None:
return self
newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments]
return self.copy(newEqs, newSubexpressions)
new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions)
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
"""Returns a new equation collection by inserting all subexpressions into the main equations"""
def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection':
    """Returns a new collection where all subexpressions have been inserted.

    Subexpressions are processed in their stored order: each one is first rewritten with the
    substitutions collected so far, then either kept (if its left hand side is in
    ``subexpressions_to_keep``) or turned into a substitution for everything that follows.

    Args:
        subexpressions_to_keep: left hand side symbols of subexpressions that should not be inserted.
                                The default set is never mutated, a copy is made below.

    Returns:
        copy of self with substituted main assignments and only the kept subexpressions
    """
    if len(self.subexpressions) == 0:
        return self.copy()

    subexpressions_to_keep = set(subexpressions_to_keep)

    first, *remaining = self.subexpressions
    if first.lhs in subexpressions_to_keep:
        # has to be a list: more subexpressions may be appended and copy() expects a list
        kept_subexpressions = [first]
        substitution_dict = {}
    else:
        kept_subexpressions = []
        substitution_dict = {first.lhs: first.rhs}

    for subexpression in remaining:
        subexpression = fast_subs(subexpression, substitution_dict)
        if subexpression.lhs in subexpressions_to_keep:
            kept_subexpressions.append(subexpression)
        else:
            substitution_dict[subexpression.lhs] = subexpression.rhs

    new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
    return self.copy(new_assignment, kept_subexpressions)
def lambdify(self, symbols, module=None, fixedSymbols={}):
"""
Returns a function to evaluate this equation collection
:param symbols: symbol(s) which are the parameter for the created function
:param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification
"""
eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments
lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
# ----------------------------------------- Display and Printing -------------------------------------------------
def f(*args, **kwargs):
return {s: f(*args, **kwargs) for s, f in lambdas.items()}
def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations):
no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">'