diff --git a/__init__.py b/__init__.py index 36cc170327d526897bc65ce00205c599de7cfb10..d4455a766682f032373ada14e5c9024b71d4a383 100644 --- a/__init__.py +++ b/__init__.py @@ -5,6 +5,7 @@ from pystencils.kernelcreation import createKernel, createIndexedKernel from pystencils.display_utils import showCode, toDot from pystencils.assignment_collection import AssignmentCollection from pystencils.assignment import Assignment +from pystencils.sympyextensions import SymbolCreator __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', 'TypedSymbol', @@ -12,4 +13,5 @@ __all__ = ['Field', 'FieldType', 'extractCommonSubexpressions', 'createKernel', 'createIndexedKernel', 'showCode', 'toDot', 'AssignmentCollection', - 'Assignment'] + 'Assignment', + 'SymbolCreator'] diff --git a/assignment_collection/assignment_collection.py b/assignment_collection/assignment_collection.py index ea95cf6b9890924485af752493aac096aa11a852..4c66773589775dc276abfa89479b97644b8c84f3 100644 --- a/assignment_collection/assignment_collection.py +++ b/assignment_collection/assignment_collection.py @@ -1,312 +1,359 @@ import sympy as sp from copy import copy +from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable from pystencils.assignment import Assignment -from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically +from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically -class AssignmentCollection(object): +class AssignmentCollection: """ - A collection of equations with subexpression definitions, also represented as equations, + A collection of equations with subexpression definitions, also represented as assignments, that are used in the main equations. AssignmentCollection can be passed to simplification methods. These simplification methods can change the subexpressions, but the number and left hand side of the main equations themselves is not altered. Additionally a dictionary of simplification hints is stored, which are set by the functions that create equation collections to transport information to the simplification system. - :ivar mainAssignments: list of sympy equations - :ivar subexpressions: list of sympy equations defining subexpressions used in main equations - :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are - used by the simplification system. See documentation of the simplification rules for - potentially required hints and their meaning. + Attributes: + main_assignments: list of assignments + subexpressions: list of assignments defining subexpressions used in main equations + simplification_hints: dict that is used to annotate the equation collection with hints that are + used by the simplification system. See documentation of the simplification rules for + potentially required hints and their meaning. + subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added + used to get new symbols that are unique for this AssignmentCollection + """ - # ----------------------------------------- Creation --------------------------------------------------------------- + # ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- - def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None): - self.mainAssignments = equations - self.subexpressions = subExpressions + def __init__(self, main_assignments: List[Assignment], subexpressions: List[Assignment], + simplification_hints: Optional[Dict[str, Any]] = None, + subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: + self.main_assignments = main_assignments + self.subexpressions = subexpressions - if simplificationHints is None: - simplificationHints = {} + if simplification_hints is None: + simplification_hints = {} - self.simplificationHints = simplificationHints + self.simplification_hints = simplification_hints - if subexpressionSymbolNameGenerator is None: - self.subexpressionSymbolNameGenerator = SymbolGen() + if subexpression_symbol_generator is None: + self.subexpression_symbol_generator = SymbolGen() else: - self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator + self.subexpression_symbol_generator = subexpression_symbol_generator - @property - def mainTerms(self): - return [] + def add_simplification_hint(self, key: str, value: Any) -> None: + """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" + assert key not in self.simplification_hints, "This hint already exists" + self.simplification_hints[key] = value - def copy(self, mainAssignments=None, subexpressions=None): - res = copy(self) - res.simplificationHints = self.simplificationHints.copy() - res.subexpressionSymbolNameGenerator = copy(self.subexpressionSymbolNameGenerator) + def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol: + """Adds a subexpression to current collection. - if mainAssignments is not None: - res.mainAssignments = mainAssignments - else: - res.mainAssignments = self.mainAssignments.copy() + Args: + rhs: right hand side of new subexpression + lhs: optional left hand side of new subexpression. If None a new unique symbol is generated. + topological_sort: sort the subexpressions topologically after insertion, to make sure that + definition of a symbol comes before its usage. If False, subexpression is appended. - if subexpressions is not None: - res.subexpressions = subexpressions - else: - res.subexpressions = self.subexpressions.copy() - - return res - - def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False, - substituteOnLhs=True): - """ - Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`. - Substitutions are made in the subexpression terms and the main equations + Returns: + left hand side symbol (which could have been generated) """ - if substituteOnLhs: - newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] - newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainAssignments] - else: - newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] - newEquations = [Assignment(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainAssignments] - - if addSubstitutionsAsSubexpressions: - newSubexpressions = [Assignment(b, a) for a, b in substitutionDict.items()] + newSubexpressions - newSubexpressions = sortEquationsTopologically(newSubexpressions) - return self.copy(newEquations, newSubexpressions) + if lhs is None: + lhs = sp.Dummy() + eq = Assignment(lhs, rhs) + self.subexpressions.append(eq) + if topological_sort: + self.topological_sort(sort_subexpressions=True, sort_main_assignments=False) + return lhs - def addSimplificationHint(self, key, value): - """ - Adds an entry to the simplificationHints dictionary, and checks that is does not exist yet - """ - assert key not in self.simplificationHints, "This hint already exists" - self.simplificationHints[key] = value + def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: + """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" + if sort_subexpressions: + self.subexpressions = sort_assignments_topologically(self.subexpressions) + if sort_main_assignments: + self.main_assignments = sort_assignments_topologically(self.main_assignments) # ---------------------------------------------- Properties ------------------------------------------------------- @property - def allEquations(self): - """Subexpression and main equations in one sequence""" - return self.subexpressions + self.mainAssignments + def all_assignments(self) -> List[Assignment]: + """Subexpression and main equations as a single list.""" + return self.subexpressions + self.main_assignments @property - def freeSymbols(self): - """All symbols used in the equation collection, which have not been defined inside the equation system""" - freeSymbols = set() - for eq in self.allEquations: - freeSymbols.update(eq.rhs.atoms(sp.Symbol)) - return freeSymbols - self.boundSymbols + def free_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" + free_symbols = set() + for eq in self.all_assignments: + free_symbols.update(eq.rhs.atoms(sp.Symbol)) + return free_symbols - self.bound_symbols @property - def boundSymbols(self): - """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined.""" - boundSymbolsSet = set([eq.lhs for eq in self.allEquations]) - assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainAssignments), \ + def bound_symbols(self) -> Set[sp.Symbol]: + """All symbols which occur on the left hand side of a main assignment or a subexpression.""" + bound_symbols_set = set([eq.lhs for eq in self.all_assignments]) + assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \ "Not in SSA form - same symbol assigned multiple times" - return boundSymbolsSet + return bound_symbols_set @property - def definedSymbols(self): - """All symbols that occur as left-hand-sides of the main equations""" - return set([eq.lhs for eq in self.mainAssignments]) + def defined_symbols(self) -> Set[sp.Symbol]: + """All symbols which occur as left-hand-sides of one of the main equations""" + return set([assignment.lhs for assignment in self.main_assignments]) @property - def operationCount(self): - """See :func:`countNumberOfOperations` """ - return countNumberOfOperations(self.allEquations, onlyType=None) + def operation_count(self): + """See :func:`count_operations` """ + return count_operations(self.all_assignments, only_type=None) + + def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: + """Returns all symbols that depend on one of the passed symbols. - def get(self, symbols, frommainAssignmentsOnly=False): - """Return the equations which have symbols as left hand sides""" + A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- someExpression(b)' i.e. when + 'b' is required to compute 'a'. + """ + + queue = list(symbols) + + def add_symbols_from_expr(expr): + dependent_symbols = expr.atoms(sp.Symbol) + for ds in dependent_symbols: + queue.append(ds) + + handled_symbols = set() + assignment_dict = {e.lhs: e.rhs for e in self.all_assignments} + + while len(queue) > 0: + e = queue.pop(0) + if e in handled_symbols: + continue + if e in assignment_dict: + add_symbols_from_expr(assignment_dict[e]) + handled_symbols.add(e) + + return handled_symbols + + def get(self, symbols: Sequence[sp.Symbol], from_main_assignments_only=False) -> List[Assignment]: + """Extracts all assignments that have a left hand side that is contained in the symbols parameter. + + Args: + symbols: return assignments that have one of these symbols as left hand side + from_main_assignments_only: search only in main assignments (exclude subexpressions) + """ if not hasattr(symbols, "__len__"): - symbols = list(symbols) - symbols = set(symbols) + symbols = set(symbols) + else: + symbols = set(symbols) - if not frommainAssignmentsOnly: - eqsToSearchIn = self.allEquations + if not from_main_assignments_only: + assignments_to_search = self.all_assignments else: - eqsToSearchIn = self.mainAssignments + assignments_to_search = self.main_assignments - return [eq for eq in eqsToSearchIn if eq.lhs in symbols] + return [assignment for assignment in assignments_to_search if assignment.lhs in symbols] - # ----------------------------------------- Display and Printing ------------------------------------------------- + def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]]=None, module=None): + """Returns a python function to evaluate this equation collection. - def _repr_html_(self): - def makeHtmlEquationTable(equations): - noBorder = 'style="border:none"' - htmlTable = '<table style="border:none; width: 100%; ">' - line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' - for eq in equations: - formatDict = {'eq': sp.latex(eq), - 'nb': noBorder, } - htmlTable += line.format(**formatDict) - htmlTable += "</table>" - return htmlTable + Args: + symbols: symbol(s) which are the parameter for the created function + fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify + module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy' - result = "" - if len(self.subexpressions) > 0: - result += "<div>Subexpressions:</div>" - result += makeHtmlEquationTable(self.subexpressions) - result += "<div>Main Assignments:</div>" - result += makeHtmlEquationTable(self.mainAssignments) - return result + Examples: + >>> a, b, c, d = sp.symbols("a b c d") + >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)], + ... subexpressions=[Assignment(b, a + b / 2)]) + >>> python_function = ac.lambdify([a], fixed_symbols={b: 2}) + >>> python_function(4) + {c: 6, d: 18} + """ + assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self + assignments = assignments.new_without_subexpressions().main_assignments + lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments} - def __repr__(self): - return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainAssignments]) + def f(*args, **kwargs): + return {s: func(*args, **kwargs) for s, func in lambdas.items()} - def __str__(self): - result = "Subexpressions\n" - for eq in self.subexpressions: - result += str(eq) + "\n" - result += "Main Assignments\n" - for eq in self.mainAssignments: - result += str(eq) + "\n" - return result + return f + # ---------------------------- Creating new modified collections --------------------------------------------------- - # ------------------------------------- Manipulation ------------------------------------------------------------ + def copy(self, + main_assignments: Optional[List[Assignment]] = None, + subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection': + """Returns a copy with optionally replaced main_assignments and/or subexpressions.""" - def merge(self, other): - """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" - ownDefs = set([e.lhs for e in self.mainAssignments]) - otherDefs = set([e.lhs for e in other.mainAssignments]) - assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols" + res = copy(self) + res.simplification_hints = self.simplification_hints.copy() + res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator) - ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions} - substitutionDict = {} + if main_assignments is not None: + res.main_assignments = main_assignments + else: + res.main_assignments = self.main_assignments.copy() - processedOtherSubexpressionEquations = [] - for otherSubexpressionEq in other.subexpressions: - if otherSubexpressionEq.lhs in ownSubexpressionSymbols: - if otherSubexpressionEq.rhs == ownSubexpressionSymbols[otherSubexpressionEq.lhs]: - continue # exact the same subexpression equation exists already - else: - # different definition - a new name has to be introduced - newLhs = next(self.subexpressionSymbolNameGenerator) - newEq = Assignment(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) - processedOtherSubexpressionEquations.append(newEq) - substitutionDict[otherSubexpressionEq.lhs] = newLhs - else: - processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict)) + if subexpressions is not None: + res.subexpressions = subexpressions + else: + res.subexpressions = self.subexpressions.copy() - processedOthermainAssignments = [fastSubs(eq, substitutionDict) for eq in other.mainAssignments] - return self.copy(self.mainAssignments + processedOthermainAssignments, - self.subexpressions + processedOtherSubexpressionEquations) + return res - def getDependentSymbols(self, symbolSequence): - """Returns a list of symbols that depend on the passed symbols.""" + def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, + substitute_on_lhs: bool = True) -> 'AssignmentCollection': + """Returns new object, where terms are substituted according to the passed substitution dict. - queue = list(symbolSequence) + Args: + substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions + add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions + substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments - def addSymbolsFromExpr(expr): - dependentSymbols = expr.atoms(sp.Symbol) - for ds in dependentSymbols: - queue.append(ds) + Returns: + New AssignmentCollection where substitutions have been applied, self is not altered. + """ + if substitute_on_lhs: + new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions] + new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments] + else: + new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions] + new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments] - handledSymbols = set() - eqMap = {e.lhs: e.rhs for e in self.allEquations} + if add_substitutions_as_subexpressions: + new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions + new_subexpressions = sort_assignments_topologically(new_subexpressions) + return self.copy(new_equations, new_subexpressions) - while len(queue) > 0: - e = queue.pop(0) - if e in handledSymbols: - continue - if e in eqMap: - addSymbolsFromExpr(eqMap[e]) - handledSymbols.add(e) + def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': + """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" + own_definitions = set([e.lhs for e in self.main_assignments]) + other_definitions = set([e.lhs for e in other.main_assignments]) + assert len(own_definitions.intersection(other_definitions)) == 0, \ + "Cannot new_merged, since both collection define the same symbols" - return handledSymbols + own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} + substitution_dict = {} - def extract(self, symbolsToExtract): - """ - Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and - only the necessary subexpressions that are used in these equations - """ - symbolsToExtract = set(symbolsToExtract) - dependentSymbols = self.getDependentSymbols(symbolsToExtract) - newEquations = [] - for eq in self.allEquations: - if eq.lhs in symbolsToExtract: - newEquations.append(eq) - - newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract] - return AssignmentCollection(newEquations, newSubExpr) - - def newWithoutUnusedSubexpressions(self): - """Returns a new equation collection containing only the subexpressions that - are used/referenced in the equations""" - allLhs = [eq.lhs for eq in self.mainAssignments] - return self.extract(allLhs) - - def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True): - if lhs is None: - lhs = sp.Dummy() - eq = Assignment(lhs, rhs) - self.subexpressions.append(eq) - if topologicalSort: - self.topologicalSort(subexpressions=True, mainAssignments=False) - return lhs + processed_other_subexpression_equations = [] + for otherSubexpressionEq in other.subexpressions: + if otherSubexpressionEq.lhs in own_subexpression_symbols: + if otherSubexpressionEq.rhs == own_subexpression_symbols[otherSubexpressionEq.lhs]: + continue # exact the same subexpression equation exists already + else: + # different definition - a new name has to be introduced + new_lhs = next(self.subexpression_symbol_generator) + new_eq = Assignment(new_lhs, fast_subs(otherSubexpressionEq.rhs, substitution_dict)) + processed_other_subexpression_equations.append(new_eq) + substitution_dict[otherSubexpressionEq.lhs] = new_lhs + else: + processed_other_subexpression_equations.append(fast_subs(otherSubexpressionEq, substitution_dict)) + + processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments] + return self.copy(self.main_assignments + processed_other_main_assignments, + self.subexpressions + processed_other_subexpression_equations) - def topologicalSort(self, subexpressions=True, mainAssignments=True): - if subexpressions: - self.subexpressions = sortEquationsTopologically(self.subexpressions) - if mainAssignments: - self.mainAssignments = sortEquationsTopologically(self.mainAssignments) + def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection': + """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions. - def insertSubexpression(self, symbol): - newSubexpressions = [] - subsDict = None + Returns: + new AssignmentCollection, self is not altered + """ + symbols_to_extract = set(symbols_to_extract) + dependent_symbols = self.dependent_symbols(symbols_to_extract) + new_assignments = [] + for eq in self.all_assignments: + if eq.lhs in symbols_to_extract: + new_assignments.append(eq) + + new_sub_expr = [eq for eq in self.subexpressions + if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] + return AssignmentCollection(new_assignments, new_sub_expr) + + def new_without_unused_subexpressions(self) -> 'AssignmentCollection': + """Returns new collection that only contains subexpressions required to compute the main assignments.""" + all_lhs = [eq.lhs for eq in self.main_assignments] + return self.new_filtered(all_lhs) + + def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection': + """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.""" + new_subexpressions = [] + subs_dict = None for se in self.subexpressions: if se.lhs == symbol: - subsDict = {se.lhs: se.rhs} + subs_dict = {se.lhs: se.rhs} else: - newSubexpressions.append(se) - if subsDict is None: + new_subexpressions.append(se) + if subs_dict is None: return self - newSubexpressions = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions] - newEqs = [Assignment(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainAssignments] - return self.copy(newEqs, newSubexpressions) + new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions] + new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] + return self.copy(new_eqs, new_subexpressions) - def insertSubexpressions(self, subexpressionSymbolsToKeep=set()): - """Returns a new equation collection by inserting all subexpressions into the main equations""" + def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection': + """Returns a new collection where all subexpressions have been inserted.""" if len(self.subexpressions) == 0: return self.copy() - subexpressionSymbolsToKeep = set(subexpressionSymbolsToKeep) + subexpressions_to_keep = set(subexpressions_to_keep) - keptSubexpressions = [] - if self.subexpressions[0].lhs in subexpressionSymbolsToKeep: - subsDict = {} - keptSubexpressions = self.subexpressions[0] + kept_subexpressions = [] + if self.subexpressions[0].lhs in subexpressions_to_keep: + substitution_dict = {} + kept_subexpressions = self.subexpressions[0] else: - subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} + substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} - subExpr = [e for e in self.subexpressions] - for i in range(1, len(subExpr)): - subExpr[i] = fastSubs(subExpr[i], subsDict) - if subExpr[i].lhs in subexpressionSymbolsToKeep: - keptSubexpressions.append(subExpr[i]) + subexpression = [e for e in self.subexpressions] + for i in range(1, len(subexpression)): + subexpression[i] = fast_subs(subexpression[i], substitution_dict) + if subexpression[i].lhs in subexpressions_to_keep: + kept_subexpressions.append(subexpression[i]) else: - subsDict[subExpr[i].lhs] = subExpr[i].rhs + substitution_dict[subexpression[i].lhs] = subexpression[i].rhs - newEq = [fastSubs(eq, subsDict) for eq in self.mainAssignments] - return self.copy(newEq, keptSubexpressions) + new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] + return self.copy(new_assignment, kept_subexpressions) - def lambdify(self, symbols, module=None, fixedSymbols={}): - """ - Returns a function to evaluate this equation collection - :param symbols: symbol(s) which are the parameter for the created function - :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy' - :param fixedSymbols: dictionary with substitutions, that are applied before lambdification - """ - eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainAssignments - lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs} + # ----------------------------------------- Display and Printing ------------------------------------------------- - def f(*args, **kwargs): - return {s: f(*args, **kwargs) for s, f in lambdas.items()} + def _repr_html_(self): + """Interface to Jupyter notebook, to display as a nicely formatted HTML table""" + def make_html_equation_table(equations): + no_border = 'style="border:none"' + html_table = '<table style="border:none; width: 100%; ">' + line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' + for eq in equations: + format_dict = {'eq': sp.latex(eq), + 'nb': no_border, } + html_table += line.format(**format_dict) + html_table += "</table>" + return html_table - return f + result = "" + if len(self.subexpressions) > 0: + result += "<div>Subexpressions:</div>" + result += make_html_equation_table(self.subexpressions) + result += "<div>Main Assignments:</div>" + result += make_html_equation_table(self.main_assignments) + return result + + def __repr__(self): + return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments]) + + def __str__(self): + result = "Subexpressions\n" + for eq in self.subexpressions: + result += str(eq) + "\n" + result += "Main Assignments\n" + for eq in self.main_assignments: + result += str(eq) + "\n" + return result class SymbolGen: + """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" def __init__(self): self._ctr = 0 diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py index 7d0eab53d92faa37d30c7d31c3db44365d79588c..b7707bc998b95885069efa92cfe2b2125404ec1f 100644 --- a/assignment_collection/simplifications.py +++ b/assignment_collection/simplifications.py @@ -1,87 +1,93 @@ import sympy as sp - +from typing import Callable, List from pystencils import Assignment, AssignmentCollection -from pystencils.sympyextensions import replaceAdditive +from pystencils.sympyextensions import subs_additive -def sympyCseOnEquationList(eqs): - ec = AssignmentCollection(eqs, []) - return sympyCSE(ec).allEquations +def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: + """Extracts common subexpressions from a list of assignments.""" + ec = AssignmentCollection(assignments, []) + return sympy_cse(ec).all_assignments -def sympyCSE(assignment_collection): - """ - Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well - as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection +def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: + """Searches for common subexpressions inside the equation collection. + + Searches is done in both the existing subexpressions as well as the assignments themselves. + It uses the sympy subexpression detection to do this. Return a new equation collection with the additional subexpressions found """ - symbolGen = assignment_collection.subexpressionSymbolNameGenerator - replacements, newEq = sp.cse(assignment_collection.subexpressions + assignment_collection.mainAssignments, - symbols=symbolGen) - replacementEqs = [Assignment(*r) for r in replacements] + symbol_gen = ac.subexpression_symbol_generator + replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments, + symbols=symbol_gen) + replacement_eqs = [Assignment(*r) for r in replacements] - modifiedSubexpressions = newEq[:len(assignment_collection.subexpressions)] - modifiedUpdateEquations = newEq[len(assignment_collection.subexpressions):] + modified_subexpressions = new_eq[:len(ac.subexpressions)] + modified_update_equations = new_eq[len(ac.subexpressions):] - newSubexpressions = replacementEqs + modifiedSubexpressions - topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) - newSubexpressions = [Assignment(a[0], a[1]) for a in topologicallySortedPairs] + new_subexpressions = replacement_eqs + modified_subexpressions + topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions]) + new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs] - return assignment_collection.copy(modifiedUpdateEquations, newSubexpressions) + return ac.copy(modified_update_equations, new_subexpressions) -def applyOnAllEquations(assignment_collection, operation): +def apply_to_all_assignments(assignment_collection: AssignmentCollection, + operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: """Applies sympy expand operation to all equations in collection""" - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.mainAssignments] + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] return assignment_collection.copy(result) -def applyOnAllSubexpressions(assignment_collection, operation): - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.subexpressions] - return assignment_collection.copy(assignment_collection.mainAssignments, result) +def apply_on_all_subexpressions(ac: AssignmentCollection, + operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] + return ac.copy(ac.main_assignments, result) -def subexpressionSubstitutionInExistingSubexpressions(assignment_collection): +def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection: """Goes through the subexpressions list and replaces the term in the following subexpressions""" result = [] - for outerCtr, s in enumerate(assignment_collection.subexpressions): - newRhs = s.rhs + for outerCtr, s in enumerate(ac.subexpressions): + new_rhs = s.rhs for innerCtr in range(outerCtr): - subExpr = assignment_collection.subexpressions[innerCtr] - newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) - newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) - result.append(Assignment(s.lhs, newRhs)) + sub_expr = ac.subexpressions[innerCtr] + new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0) + new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs) + result.append(Assignment(s.lhs, new_rhs)) - return assignment_collection.copy(assignment_collection.mainAssignments, result) + return ac.copy(ac.main_assignments, result) -def subexpressionSubstitutionInmainAssignments(assignment_collection): - """Replaces already existing subexpressions in the equations of the assignment_collection""" +def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection: + """Replaces already existing subexpressions in the equations of the assignment_collection.""" result = [] - for s in assignment_collection.mainAssignments: - newRhs = s.rhs - for subExpr in assignment_collection.subexpressions: - newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) - result.append(Assignment(s.lhs, newRhs)) - return assignment_collection.copy(result) + for s in ac.main_assignments: + new_rhs = s.rhs + for subExpr in ac.subexpressions: + new_rhs = subs_additive(new_rhs, subExpr.lhs, subExpr.rhs, required_match_replacement=1.0) + result.append(Assignment(s.lhs, new_rhs)) + return ac.copy(result) -def addSubexpressionsForDivisions(assignment_collection): +def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection: """Introduces subexpressions for all divisions which have no constant in the denominator. - e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.""" + + For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced. + """ divisors = set() - def searchDivisors(term): + def search_divisors(term): if term.func == sp.Pow: if term.exp.is_integer and term.exp.is_number and term.exp < 0: divisors.add(term) else: for a in term.args: - searchDivisors(a) + search_divisors(a) - for eq in assignment_collection.allEquations: - searchDivisors(eq.rhs) + for eq in ac.all_assignments: + search_divisors(eq.rhs) - newSymbolGen = assignment_collection.subexpressionSymbolNameGenerator - substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} - return assignment_collection.copyWithSubstitutionsApplied(substitutions, True) + new_symbol_gen = ac.subexpression_symbol_generator + substitutions = {divisor: newSymbol for newSymbol, divisor in zip(new_symbol_gen, divisors)} + return ac.new_with_substitutions(substitutions, True) diff --git a/assignment_collection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py index 3d8cdd62fb0a93752f524e2cbf1c62f384949f3f..fad2279544a4a6085fbcd091be50ac1688fbbd38 100644 --- a/assignment_collection/simplificationstrategy.py +++ b/assignment_collection/simplificationstrategy.py @@ -1,10 +1,12 @@ import sympy as sp from collections import namedtuple +from typing import Callable, Any, Optional, Sequence +from pystencils.assignment_collection.assignment_collection import AssignmentCollection class SimplificationStrategy(object): - """ - A simplification strategy is an ordered collection of simplification rules. + """A simplification strategy is an ordered collection of simplification rules. + Each simplification is a function taking an equation collection, and returning a new simplified equation collection. The strategy can nicely print intermediate simplification stages and results to Jupyter notebooks. @@ -13,10 +15,11 @@ class SimplificationStrategy(object): def __init__(self): self._rules = [] - def add(self, rule): - """ - Adds the given simplification rule to the end of the collection. - :param rule: function that taking one equation collection and returning a (simplified) equation collection + def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None: + """Adds the given simplification rule to the end of the collection. + + Args: + rule: function that rewrites/simplifies an assignment collection """ self._rules.append(rule) @@ -24,19 +27,20 @@ class SimplificationStrategy(object): def rules(self): return self._rules - def apply(self, updateRule): - """Applies all simplification rules to the given equation collection""" + def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection: + """Runs all rules on the given assignment collection.""" for t in self._rules: - updateRule = t(updateRule) - return updateRule + assignment_collection = t(assignment_collection) + return assignment_collection - def __call__(self, assignment_collection): + def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection: """Same as apply""" return self.apply(assignment_collection) - def createSimplificationReport(self, assignment_collection): - """ - Returns a simplification report containing the number of operations at each simplification stage, together + def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any: + """Creates a report to be displayed as HTML in a Jupyter notebook. + + The simplification report contains the number of operations at each simplification stage together with the run-time the simplification took. """ @@ -60,70 +64,83 @@ class SimplificationStrategy(object): return result def _repr_html_(self): - htmlTable = '<table style="border:none">' - htmlTable += "<tr><th>Name</th><th>Runtime</th><th>Adds</th><th>Muls</th><th>Divs</th><th>Total</th></tr>" + html_table = '<table style="border:none">' + html_table += "<tr><th>Name</th>" \ + "<th>Runtime</th>" \ + "<th>Adds</th>" \ + "<th>Muls</th>" \ + "<th>Divs</th>" \ + "<th>Total</th></tr>" line = "<tr><td>{simplificationName}</td>" \ "<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>" for e in self.elements: - htmlTable += line.format(**e._asdict()) - htmlTable += "</table>" - return htmlTable + # noinspection PyProtectedMember + html_table += line.format(**e._asdict()) + html_table += "</table>" + return html_table import timeit report = Report() - op = assignment_collection.operationCount + op = assignment_collection.operation_count total = op['adds'] + op['muls'] + op['divs'] report.add(ReportElement("OriginalTerm", '-', op['adds'], op['muls'], op['divs'], total)) for t in self._rules: - startTime = timeit.default_timer() + start_time = timeit.default_timer() assignment_collection = t(assignment_collection) - endTime = timeit.default_timer() - op = assignment_collection.operationCount - timeStr = "%.2f ms" % ((endTime - startTime) * 1000,) + end_time = timeit.default_timer() + op = assignment_collection.operation_count + time_str = "%.2f ms" % ((end_time - start_time) * 1000,) total = op['adds'] + op['muls'] + op['divs'] - report.add(ReportElement(t.__name__, timeStr, op['adds'], op['muls'], op['divs'], total)) + report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total)) return report - def showIntermediateResults(self, assignment_collection, symbols=None): + def show_intermediate_results(self, assignment_collection: AssignmentCollection, + symbols: Optional[Sequence[sp.Symbol]] = None) -> Any: + """Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook. + Args: + assignment_collection: the collection to apply the rules to + symbols: if not None, only the assignments are shown that have one of these symbols as left hand side + """ class IntermediateResults: - def __init__(self, strategy, eqColl, resSyms): + def __init__(self, strategy, collection, restrict_symbols): self.strategy = strategy - self.assignment_collection = eqColl - self.restrictSymbols = resSyms + self.assignment_collection = collection + self.restrict_symbols = restrict_symbols def __str__(self): - def printEqCollection(title, eqColl): + def print_assignment_collection(title, c): text = title - if self.restrictSymbols: - text += "\n".join([str(e) for e in eqColl.get(self.restrictSymbols)]) + if self.restrict_symbols: + text += "\n".join([str(e) for e in c.get(self.restrict_symbols)]) else: - text += (" " * 3 + (" " * 3).join(str(eqColl).splitlines(True))) + text += (" " * 3 + (" " * 3).join(str(c).splitlines(True))) return text - result = printEqCollection("Initial Version", self.assignment_collection) - eqColl = self.assignment_collection + result = print_assignment_collection("Initial Version", self.assignment_collection) + collection = self.assignment_collection for rule in self.strategy.rules: - eqColl = rule(eqColl) - result += printEqCollection(rule.__name__, eqColl) + collection = rule(collection) + result += print_assignment_collection(rule.__name__, collection) return result def _repr_html_(self): - def printEqCollection(title, eqColl): + def print_assignment_collection(title, c): text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, ) - if self.restrictSymbols: - text += "\n".join(["$$" + sp.latex(e) + '$$' for e in eqColl.get(self.restrictSymbols)]) + if self.restrict_symbols: + text += "\n".join(["$$" + sp.latex(e) + '$$' for e in c.get(self.restrict_symbols)]) else: - text += eqColl._repr_html_() + # noinspection PyProtectedMember + text += c._repr_html_() text += "</div>" return text - result = printEqCollection("Initial Version", self.assignment_collection) - eqColl = self.assignment_collection + result = print_assignment_collection("Initial Version", self.assignment_collection) + collection = self.assignment_collection for rule in self.strategy.rules: - eqColl = rule(eqColl) - result += printEqCollection(rule.__name__, eqColl) + collection = rule(collection) + result += print_assignment_collection(rule.__name__, collection) return result return IntermediateResults(self, assignment_collection, symbols) diff --git a/astnodes.py b/astnodes.py index e98c73ad652a9ccbd481b4bb0b7ccc348e06b88c..58780ac4ea1d745af1986cb1a56bd357cb87199c 100644 --- a/astnodes.py +++ b/astnodes.py @@ -2,7 +2,7 @@ import sympy as sp from sympy.tensor import IndexedBase from pystencils.field import Field from pystencils.data_types import TypedSymbol, createType, castFunc -from pystencils.sympyextensions import fastSubs +from pystencils.sympyextensions import fast_subs class Node(object): @@ -275,11 +275,11 @@ class Block(Node): @property def undefinedSymbols(self): result = set() - definedSymbols = set() + defined_symbols = set() for a in self.args: result.update(a.undefinedSymbols) - definedSymbols.update(a.symbolsDefined) - return result - definedSymbols + defined_symbols.update(a.symbolsDefined) + return result - defined_symbols def __str__(self): return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes) @@ -426,8 +426,8 @@ class SympyAssignment(Node): self._isDeclaration = False def subs(self, *args, **kwargs): - self.lhs = fastSubs(self.lhs, *args, **kwargs) - self.rhs = fastSubs(self.rhs, *args, **kwargs) + self.lhs = fast_subs(self.lhs, *args, **kwargs) + self.rhs = fast_subs(self.rhs, *args, **kwargs) @property def args(self): @@ -494,11 +494,11 @@ class ResolvedFieldAccess(sp.Indexed): self.args[1].subs(old, new), self.field, self.offsets, self.idxCoordinateValues) - def fastSubs(self, subsDict): - if self in subsDict: - return subsDict[self] - return ResolvedFieldAccess(self.args[0].subs(subsDict), - self.args[1].subs(subsDict), + def fast_subs(self, substitutions): + if self in substitutions: + return substitutions[self] + return ResolvedFieldAccess(self.args[0].subs(substitutions), + self.args[1].subs(substitutions), self.field, self.offsets, self.idxCoordinateValues) def _hashable_content(self): diff --git a/derivative.py b/derivative.py index 75db20b5e3ce05a018572a90e46c87851360392e..c2dadd8e4e540d62f66b5c33393a366bbdc3f246 100644 --- a/derivative.py +++ b/derivative.py @@ -1,6 +1,6 @@ import sympy as sp from collections import namedtuple, defaultdict -from pystencils.sympyextensions import normalizeProduct, prod +from pystencils.sympyextensions import normalize_product, prod def defaultDiffSortKey(d): @@ -57,7 +57,7 @@ class Diff(sp.Expr): if self.arg.func != sp.Mul: constant, variable = 1, self.arg else: - for factor in normalizeProduct(self.arg): + for factor in normalize_product(self.arg): if factor in functions or isinstance(factor, Diff): variable *= factor else: @@ -150,7 +150,7 @@ class DiffOperator(sp.Expr): i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t) """ def handleMul(mul): - args = normalizeProduct(mul) + args = normalize_product(mul) diffs = [a for a in args if isinstance(a, DiffOperator)] if len(diffs) == 0: return mul * argument if applyToConstants else mul @@ -254,7 +254,7 @@ def fullDiffExpand(expr, functions=None, constants=None): for term in diffInner.args if diffInner.func == sp.Add else [diffInner]: independentTerms = 1 dependentTerms = [] - for factor in normalizeProduct(term): + for factor in normalize_product(term): if factor in functions or isinstance(factor, Diff): dependentTerms.append(factor) else: @@ -310,7 +310,7 @@ def expandUsingProductRule(expr): if arg.func not in (sp.Mul, sp.Pow): return Diff(arg, target=expr.target, superscript=expr.superscript) else: - prodList = normalizeProduct(arg) + prodList = normalize_product(arg) result = 0 for i in range(len(prodList)): preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j) @@ -347,7 +347,7 @@ def combineUsingProductRule(expr): if isinstance(term, Diff): diffDict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg)) else: - mulArgs = normalizeProduct(term) + mulArgs = normalize_product(term) diffs = [d for d in mulArgs if isinstance(d, Diff)] factor = prod(d for d in mulArgs if not isinstance(d, Diff)) if len(diffs) == 0: diff --git a/field.py b/field.py index d310f9b77d341205912bc6c030fa2877c9cd3166..60ac7dd32e7e0b225dbc8b7c905417d7fe13655f 100644 --- a/field.py +++ b/field.py @@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase from pystencils.assignment import Assignment from pystencils.alignedarray import aligned_empty from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFromString, StructType -from pystencils.sympyextensions import isIntegerSequence +from pystencils.sympyextensions import is_integer_sequence class FieldType(Enum): @@ -221,7 +221,7 @@ class Field(object): @property def hasFixedShape(self): - return isIntegerSequence(self.shape) + return is_integer_sequence(self.shape) @property def indexShape(self): @@ -229,7 +229,7 @@ class Field(object): @property def hasFixedIndexShape(self): - return isIntegerSequence(self.indexShape) + return is_integer_sequence(self.indexShape) @property def spatialStrides(self): diff --git a/finitedifferences.py b/finitedifferences.py index 8dba62707946d985d3f5d381f3923018519f48b7..189e414a85d5cb7d316b246c8301f67949fc880d 100644 --- a/finitedifferences.py +++ b/finitedifferences.py @@ -3,7 +3,7 @@ import sympy as sp from pystencils.assignment_collection import AssignmentCollection from pystencils.field import Field -from pystencils.transformations import fastSubs +from pystencils.sympyextensions import fast_subs from pystencils.derivative import Diff @@ -103,7 +103,7 @@ def discretizeStaggered(term, symbolsToFieldDict, coordinate, coordinateOffset, neighborGrad = (field[up+offset](i) - field[down+offset](i)) / (2 * dx) substitutions[grad(s)[d]] = (centerGrad + neighborGrad) / 2 - return fastSubs(term, substitutions) + return fast_subs(term, substitutions) def discretizeDivergence(vectorTerm, symbolsToFieldDict, dx): @@ -356,7 +356,7 @@ class Discretization2ndOrder: elif isinstance(expr, sp.Matrix): return expr.applyfunc(self.__call__) elif isinstance(expr, AssignmentCollection): - return expr.copy(mainAssignments=[e for e in expr.mainAssignments], + return expr.copy(main_assignments=[e for e in expr.main_assignments], subexpressions=[e for e in expr.subexpressions]) transientTerms = expr.atoms(Transient) diff --git a/kerncraft_coupling/kerncraft_interface.py b/kerncraft_coupling/kerncraft_interface.py index b40f1021f5d3eb69b6a02aa56efb330828826026..e44bd1b9b4e2bbb91a3d0b47a18d171f555df0c0 100644 --- a/kerncraft_coupling/kerncraft_interface.py +++ b/kerncraft_coupling/kerncraft_interface.py @@ -12,7 +12,7 @@ from kerncraft.iaca import iaca_analyse_instrumented_binary, iaca_instrumentatio from pystencils.kerncraft_coupling.generate_benchmark import generateBenchmark from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFieldAccess from pystencils.field import getLayoutFromStrides -from pystencils.sympyextensions import countNumberOfOperationsInAst +from pystencils.sympyextensions import count_operations_in_ast from pystencils.utils import DotDict @@ -78,7 +78,7 @@ class PyStencilsKerncraftKernel(kerncraft.kernel.Kernel): self.datatype = list(self.variables.values())[0][0] # flops - operationCount = countNumberOfOperationsInAst(innerLoop) + operationCount = count_operations_in_ast(innerLoop) self._flops = { '+': operationCount['adds'], '*': operationCount['muls'], diff --git a/kernelcreation.py b/kernelcreation.py index 9c84347d93f940ed783a1cd45681b595c7c1c98c..c9599c2bd2dd35393ca709549c96b3842d120d33 100644 --- a/kernelcreation.py +++ b/kernelcreation.py @@ -33,9 +33,9 @@ def createKernel(equations, target='cpu', dataType="double", iterationSlice=None # ---- Normalizing parameters splitGroups = () if isinstance(equations, AssignmentCollection): - if 'splitGroups' in equations.simplificationHints: - splitGroups = equations.simplificationHints['splitGroups'] - equations = equations.allEquations + if 'splitGroups' in equations.simplification_hints: + splitGroups = equations.simplification_hints['splitGroups'] + equations = equations.all_assignments # ---- Creating ast if target == 'cpu': @@ -84,7 +84,7 @@ def createIndexedKernel(equations, indexFields, target='cpu', dataType="double", """ if isinstance(equations, AssignmentCollection): - equations = equations.allEquations + equations = equations.all_assignments if target == 'cpu': from pystencils.cpu import createIndexedKernel from pystencils.cpu import addOpenMP diff --git a/sympyextensions.py b/sympyextensions.py index d407d8f61ef463ecb2fcabe00af598f99c63ece7..8bcfae4cff34ccfaecc5a6b635b1e9951cba5e09 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -1,76 +1,80 @@ -import operator -from functools import reduce -from collections import defaultdict, Sequence import itertools import warnings +import operator +from functools import reduce, partial +from collections import defaultdict, Counter import sympy as sp +from sympy.functions import Abs +from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple from pystencils.data_types import getTypeOfExpression, getBaseType from pystencils.assignment import Assignment +T = TypeVar('T') + -def prod(seq): +def prod(seq: Sequence[T]) -> T: """Takes a sequence and returns the product of all elements""" return reduce(operator.mul, seq, 1) -def allIn(a, b): - """Tests if all elements of a container 'a' are contained in 'b'""" - return all(element in b for element in a) - - -def isIntegerSequence(sequence): +def is_integer_sequence(sequence: Iterable) -> bool: + """Checks if all elements of the passed sequence can be cast to integers""" try: - [int(i) for i in sequence] + for i in sequence: + int(i) return True except TypeError: return False -def scalarProduct(a, b): +def scalar_product(a: Iterable[T], b: Iterable[T]) -> T: + """Scalar product between two sequences.""" return sum(a_i * b_i for a_i, b_i in zip(a, b)) -def equationsToMatrix(equations, degreesOfFreedom): - return sp.Matrix(len(equations), len(degreesOfFreedom), - lambda row, col: equations[row].coeff(degreesOfFreedom[col])) - - -def kroneckerDelta(*args): - """Kronecker delta for variable number of arguments, - 1 if all args are equal, otherwise 0""" +def kronecker_delta(*args): + """Kronecker delta for variable number of arguments, 1 if all args are equal, otherwise 0""" for a in args: if a != args[0]: return 0 return 1 -def multidimensionalSummation(i, dim): - """Multidimensional summation""" - prodArgs = [range(dim)] * i - return itertools.product(*prodArgs) - +def multidimensional_sum(i, dim): + """Multidimensional summation -def normalizeProduct(product): + Example: + >>> list(multidimensional_sum(2, dim=3)) + [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)] """ - Expects a sympy expression that can be interpreted as a product and - - for a Mul node returns its factors ('args') - - for a Pow node with positive integer exponent returns a list of factors - - for other node types [product] is returned + prod_args = [range(dim)] * i + return itertools.product(*prod_args) + + +def normalize_product(product: sp.Expr) -> List[sp.Expr]: + """Expects a sympy expression that can be interpreted as a product and returns a list of all factors. + + Removes sp.Pow nodes that have integer exponent by representing them as single factors in list. + + Returns: + * for a Mul node list of factors ('args') + * for a Pow node with positive integer exponent a list of factors + * for other node types [product] is returned """ - def handlePow(power): + def handle_pow(power): if power.exp.is_integer and power.exp.is_number and power.exp > 0: return [power.base] * power.exp else: return [power] - if product.func == sp.Pow: - return handlePow(product) - elif product.func == sp.Mul: + if isinstance(product, sp.Pow): + return handle_pow(product) + elif isinstance(product, sp.Mul): result = [] for a in product.args: if a.func == sp.Pow: - result += handlePow(a) + result += handle_pow(a) else: result.append(a) return result @@ -78,357 +82,360 @@ def normalizeProduct(product): return [product] -def productSymmetric(*args, withDiagonal=True): - """Similar to itertools.product but returns only values where the index is ascending i.e. values below diagonal""" +def symmetric_product(*args, with_diagonal: bool = True) -> Iterable: + """Similar to itertools.product but yields only values where the index is ascending i.e. values below/up to diagonal + + Examples: + >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'])) + [(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')] + >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False)) + [(1, 'b'), (1, 'c'), (2, 'c')] + """ ranges = [range(len(a)) for a in args] for idx in itertools.product(*ranges): - validIndex = True + valid_index = True for t in range(1, len(idx)): - if (withDiagonal and idx[t - 1] > idx[t]) or (not withDiagonal and idx[t - 1] >= idx[t]): - validIndex = False + if (with_diagonal and idx[t - 1] > idx[t]) or (not with_diagonal and idx[t - 1] >= idx[t]): + valid_index = False break - if validIndex: + if valid_index: yield tuple(a[i] for a, i in zip(args, idx)) -def fastSubs(term, subsDict, skip=None): +def fast_subs(expression: T, substitutions: Dict[sp.Expr, sp.Expr], + skip: Optional[Callable[[sp.Expr], bool]] = None) -> T: """Similar to sympy subs function. - This version is much faster for big substitution dictionaries than sympy version""" + + Args: + expression: expression where parts should be substituted + substitutions: dict defining substitutions by mapping from old to new terms + skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped + expressions no substitutions are done + + This version is much faster for big substitution dictionaries than sympy version + """ + if type(expression) is sp.Matrix: + return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions)) + def visit(expr): if skip and skip(expr): return expr - if hasattr(expr, "fastSubs"): - return expr.fastSubs(subsDict) - if expr in subsDict: - return subsDict[expr] + if hasattr(expr, "fast_subs"): + return expr.fast_subs(substitutions) + if expr in substitutions: + return substitutions[expr] if not hasattr(expr, 'args'): return expr - paramList = [visit(a) for a in expr.args] - return expr if not paramList else expr.func(*paramList) + param_list = [visit(a) for a in expr.args] + return expr if not param_list else expr.func(*param_list) - if len(subsDict) == 0: - return term + if len(substitutions) == 0: + return expression else: - return visit(term) + return visit(expression) + +def fast_subs_and_normalize(expression, substitutions: Dict[sp.Expr, sp.Expr], + normalize: Callable[[sp.Expr], sp.Expr]) -> sp.Expr: + """Similar to fast_subs, but calls a normalization function on all substituted terms to save one AST traversal.""" -def fastSubsWithNormalize(term, subsDict, normalizeFunc): def visit(expr): - if expr in subsDict: - return subsDict[expr], True + if expr in substitutions: + return substitutions[expr], True if not hasattr(expr, 'args'): return expr, False - paramList = [] + param_list = [] substituted = False for a in expr.args: - replacedExpr, s = visit(a) - paramList.append(replacedExpr) + replaced_expr, s = visit(a) + param_list.append(replaced_expr) if s: substituted = True - if not paramList: + if not param_list: return expr, False else: if substituted: - result, _ = visit(normalizeFunc(expr.func(*paramList))) + result, _ = visit(normalize(expr.func(*param_list))) return result, True else: - return expr.func(*paramList), False + return expr.func(*param_list), False - if len(subsDict) == 0: - return term + if len(substitutions) == 0: + return expression else: - res, _ = visit(term) + res, _ = visit(expression) return res -def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None): - """ - Transformation for replacing a given subexpression inside a sum - - Example 1: - expr = 3*x + 3 * y - replacement = k - subExpression = x+y - return = 3*k - - Example 2: - expr = 3*x + 3 * y + z - replacement = k - subExpression = x+y+z - return: - if minimalMatchingTerms >=3 the expression would not be altered - if smaller than 3 the result is 3*k - 2*z - - :param expr: input expression - :param replacement: expression that is inserted for subExpression (if found) - :param subExpression: expression to replace - :param requiredMatchReplacement: - - if float: the percentage of terms of the subExpression that has to be matched in order to replace - - if integer: the total number of terms that has to be matched in order to replace - - None: is equal to integer 1 - - if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND) - :param requiredMatchOriginal: - - if float: the percentage of terms of the original addition expression that has to be matched - - if integer: the total number of terms that has to be matched in order to replace - - None: is equal to integer 1 - :return: new expression with replacement +def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, + required_match_replacement: Optional[Union[int, float]] = 0.5, + required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: + """Transformation for replacing a given subexpression inside a sum. + + Examples: + The next example demonstrates the advantage of replace_additive compared to sympy.subs: + >>> x, y, z, k = sp.symbols("x y z k") + >>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y) + 3*k + + Terms that don't match completely can be substituted at the cost of additional terms. + This trade-off is managed using the required_match parameters. + >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0) + 3*x + 3*y + z + >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5) + 3*k - 2*z + + Args: + expr: input expression + replacement: expression that is inserted for subExpression (if found) + subexpression: expression to replace + required_match_replacement: + * if float: the percentage of terms of the subExpression that has to be matched in order to replace + * if integer: the total number of terms that has to be matched in order to replace + * None: is equal to integer 1 + * if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND) + required_match_original: + * if float: the percentage of terms of the original addition expression that has to be matched + * if integer: the total number of terms that has to be matched in order to replace + * None: is equal to integer 1 + + Returns: + new expression with replacement """ - def normalizeMatchParameter(matchParameter, expressingLength): - if matchParameter is None: + def normalize_match_parameter(match_parameter, expression_length): + if match_parameter is None: return 1 - elif isinstance(matchParameter, float): - assert 0 <= matchParameter <= 1 - res = int(matchParameter * expressingLength) + elif isinstance(match_parameter, float): + assert 0 <= match_parameter <= 1 + res = int(match_parameter * expression_length) return max(res, 1) - elif isinstance(matchParameter, int): - assert matchParameter > 0 - return matchParameter + elif isinstance(match_parameter, int): + assert match_parameter > 0 + return match_parameter raise ValueError("Invalid parameter") - normalizedReplacementMatch = normalizeMatchParameter(requiredMatchReplacement, len(subExpression.args)) + normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args)) - def visit(currentExpr): - if currentExpr.is_Add: - exprMaxLength = max(len(currentExpr.args), len(subExpression.args)) - normalizedCurrentExprMatch = normalizeMatchParameter(requiredMatchOriginal, exprMaxLength) - exprCoeffs = currentExpr.as_coefficients_dict() - subexprCoeffDict = subExpression.as_coefficients_dict() - intersection = set(subexprCoeffDict.keys()).intersection(set(exprCoeffs)) - if len(intersection) >= max(normalizedReplacementMatch, normalizedCurrentExprMatch): + def visit(current_expr): + if current_expr.is_Add: + expr_max_length = max(len(current_expr.args), len(subexpression.args)) + normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length) + expr_coefficients = current_expr.as_coefficients_dict() + subexpression_coefficient_dict = subexpression.as_coefficients_dict() + intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients)) + if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match): # find common factor factors = defaultdict(lambda: 0) skips = 0 - for commonSymbol in subexprCoeffDict.keys(): - if commonSymbol not in exprCoeffs: + for commonSymbol in subexpression_coefficient_dict.keys(): + if commonSymbol not in expr_coefficients: skips += 1 continue - factor = exprCoeffs[commonSymbol] / subexprCoeffDict[commonSymbol] + factor = expr_coefficients[commonSymbol] / subexpression_coefficient_dict[commonSymbol] factors[sp.simplify(factor)] += 1 - commonFactor = max(factors.items(), key=operator.itemgetter(1))[0] - if factors[commonFactor] >= max(normalizedCurrentExprMatch, normalizedReplacementMatch): - return currentExpr - commonFactor * subExpression + commonFactor * replacement + common_factor = max(factors.items(), key=operator.itemgetter(1))[0] + if factors[common_factor] >= max(normalized_current_expr_match, normalized_replacement_match): + return current_expr - common_factor * subexpression + common_factor * replacement # if no subexpression was found - paramList = [visit(a) for a in currentExpr.args] - if not paramList: - return currentExpr + param_list = [visit(a) for a in current_expr.args] + if not param_list: + return current_expr else: - return currentExpr.func(*paramList, evaluate=False) + return current_expr.func(*param_list, evaluate=False) return visit(expr) -def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=None): - """ - Replaces second order mixed terms like x*y by 2* ( (x+y)**2 - x**2 - y**2 ) +def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], + positive: Optional[bool] = None, + replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: + """Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). + This makes the term longer - simplify usually is undoing these - however this transformation can be done to find more common sub-expressions - :param expr: input expression - :param searchSymbols: list of symbols that are searched for - Example: given [ x,y,z] terms like x*y, x*z, z*y are replaced - :param positive: there are two ways to do this substitution, either with term - (x+y)**2 or (x-y)**2 . if positive=True the first version is done, - if positive=False the second version is done, if positive=None the - sign is determined by the sign of the mixed term that is replaced - :param replaceMixed: if a list is passed here the expr x+y or x-y is replaced by a special new symbol - the replacement equation is added to the list - :return: + + Args: + expr: input expression + search_symbols: symbols that are searched for + for example, given [x,y,z] terms like x*y, x*z, z*y are replaced + positive: there are two ways to do this substitution, either with term + (x+y)**2 or (x-y)**2 . if positive=True the first version is done, + if positive=False the second version is done, if positive=None the + sign is determined by the sign of the mixed term that is replaced + replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol + and the replacement equation is added to the list """ - if replaceMixed is not None: - mixedSymbolsReplaced = set([e.lhs for e in replaceMixed]) + mixed_symbols_replaced = set([e.lhs for e in replace_mixed]) if replace_mixed is not None else set() if expr.is_Mul: - distinctVelTerms = set() - nrOfVelTerms = 0 - otherFactors = 1 + distinct_search_symbols = set() + nr_of_search_terms = 0 + other_factors = 1 for t in expr.args: - if t in searchSymbols: - nrOfVelTerms += 1 - distinctVelTerms.add(t) + if t in search_symbols: + nr_of_search_terms += 1 + distinct_search_symbols.add(t) else: - otherFactors *= t - if len(distinctVelTerms) == 2 and nrOfVelTerms == 2: - u, v = sorted(list(distinctVelTerms), key=lambda symbol: symbol.name) + other_factors *= t + if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2: + u, v = sorted(list(distinct_search_symbols), key=lambda symbol: symbol.name) if positive is None: - otherFactorsWithoutSymbols = otherFactors - for s in otherFactors.atoms(sp.Symbol): - otherFactorsWithoutSymbols = otherFactorsWithoutSymbols.subs(s, 1) - positive = otherFactorsWithoutSymbols.is_positive + other_factors_without_symbols = other_factors + for s in other_factors.atoms(sp.Symbol): + other_factors_without_symbols = other_factors_without_symbols.subs(s, 1) + positive = other_factors_without_symbols.is_positive assert positive is not None sign = 1 if positive else -1 - if replaceMixed is not None: - newSymbolStr = 'P' if positive else 'M' - mixedSymbolName = u.name + newSymbolStr + v.name - mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", "")) - if mixedSymbol not in mixedSymbolsReplaced: - mixedSymbolsReplaced.add(mixedSymbol) - replaceMixed.append(Assignment(mixedSymbol, u + sign * v)) + if replace_mixed is not None: + new_symbol_str = 'P' if positive else 'M' + mixed_symbol_name = u.name + new_symbol_str + v.name + mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", "")) + if mixed_symbol not in mixed_symbols_replaced: + mixed_symbols_replaced.add(mixed_symbol) + replace_mixed.append(Assignment(mixed_symbol, u + sign * v)) else: - mixedSymbol = u + sign * v - return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2) + mixed_symbol = u + sign * v + return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2) - paramList = [replaceSecondOrderProducts(a, searchSymbols, positive, replaceMixed) for a in expr.args] - result = expr.func(*paramList, evaluate=False) if paramList else expr + param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args] + result = expr.func(*param_list, evaluate=False) if param_list else expr return result -def removeHigherOrderTerms(term, order=3, symbols=None): - """ - Removes all terms that that contain more than 'order' factors of given 'symbols' +def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr: + """Removes all terms that contain more than 'order' factors of given 'symbols' Example: >>> x, y = sp.symbols("x y") >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2 - >>> removeHigherOrderTerms(term, order=2, symbols=[x, y]) + >>> remove_higher_order_terms(term, order=2, symbols=[x, y]) x + y**2 """ from sympy.core.power import Pow from sympy.core.add import Add, Mul result = 0 - term = term.expand() - - if not symbols: - symbols = sp.symbols(" ".join(["u_%d" % (i,) for i in range(3)])) - symbols += sp.symbols(" ".join(["u_%d" % (i,) for i in range(3)]), real=True) + expr = expr.expand() - def velocityFactorsInProduct(product): - uFactorCount = 0 + def velocity_factors_in_product(product): + factor_count = 0 if type(product) is Mul: for factor in product.args: if type(factor) == Pow: if factor.args[0] in symbols: - uFactorCount += factor.args[1] + factor_count += factor.args[1] if factor in symbols: - uFactorCount += 1 + factor_count += 1 elif type(product) is Pow: if product.args[0] in symbols: - uFactorCount += product.args[1] - return uFactorCount + factor_count += product.args[1] + return factor_count - if type(term) == Mul or type(term) == Pow: - if velocityFactorsInProduct(term) <= order: - return term + if type(expr) == Mul or type(expr) == Pow: + if velocity_factors_in_product(expr) <= order: + return expr else: return sp.Rational(0, 1) - if type(term) != Add: - return term + if type(expr) != Add: + return expr - for sumTerm in term.args: - if velocityFactorsInProduct(sumTerm) <= order: + for sumTerm in expr.args: + if velocity_factors_in_product(sumTerm) <= order: result += sumTerm return result -def completeTheSquare(expr, symbolToComplete, newVariable): - """ - Transforms second order polynomial into only squared part i.e. - a*symbolToComplete**2 + b*symbolToComplete + c - is transformed into - newVariable**2 + d +def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol, + new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]: + """Transforms second order polynomial into only squared part. - returns replacedExpr, "a tuple to to replace newVariable such that old expr comes out again" + Examples: + >>> a, b, c, s, n = sp.symbols("a b c s n") + >>> expr = a * s**2 + b * s + c + >>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n) + >>> completed_expr + a*n**2 + c - b**2/(4*a) + >>> substitution + (n, s + b/(2*a)) - if given expr is not a second order polynomial: - return expr, None + Returns: + (replacedExpr, tuple to pass to subs, such that old expr comes out again) """ - p = sp.Poly(expr, symbolToComplete) - coeffs = p.all_coeffs() - if len(coeffs) != 3: + p = sp.Poly(expr, symbol_to_complete) + coefficients = p.all_coeffs() + if len(coefficients) != 3: return expr, None - a, b, _ = coeffs - expr = expr.subs(symbolToComplete, newVariable - b / (2 * a)) - return sp.simplify(expr), (newVariable, symbolToComplete + b / (2 * a)) + a, b, _ = coefficients + expr = expr.subs(symbol_to_complete, new_variable - b / (2 * a)) + return sp.simplify(expr), (new_variable, symbol_to_complete + b / (2 * a)) -def makeExponentialFuncArgumentSquares(expr, variablesToCompleteSquares): - """Completes squares in arguments of exponential which makes them simpler to integrate - Very useful for integrating Maxwell-Boltzmann and its moment generating function""" - expr = sp.simplify(expr) - dim = len(variablesToCompleteSquares) - dummies = [sp.Dummy() for i in range(dim)] +def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]): + """Completes squares in arguments of exponential which makes them simpler to integrate. + + Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function + """ + dummies = [sp.Dummy() for _ in symbols_to_complete] def visit(term): if term.func == sp.exp: - expArg = term.args[0] - for i in range(dim): - expArg, substitution = completeTheSquare(expArg, variablesToCompleteSquares[i], dummies[i]) - return sp.exp(sp.expand(expArg)) + exp_arg = term.args[0] + for symbol_to_complete, dummy in zip(symbols_to_complete, dummies): + exp_arg, substitution = complete_the_square(exp_arg, symbol_to_complete, dummy) + return sp.exp(sp.expand(exp_arg)) else: - paramList = [visit(a) for a in term.args] - if not paramList: + param_list = [visit(a) for a in term.args] + if not param_list: return term else: - return term.func(*paramList) + return term.func(*param_list) result = visit(expr) - for i in range(dim): - result = result.subs(dummies[i], variablesToCompleteSquares[i]) + for s, d in zip(symbols_to_complete, dummies): + result = result.subs(d, s) return result def pow2mul(expr): - """ - Convert integer powers in an expression to Muls, like a**2 => a*a. - """ - pows = list(expr.atoms(sp.Pow)) - if any(not e.is_Integer for b, e in (i.as_base_exp() for i in pows)): + """Convert integer powers in an expression to Muls, like a**2 => a*a. """ + powers = list(expr.atoms(sp.Pow)) + if any(not e.is_Integer for b, e in (i.as_base_exp() for i in powers)): raise ValueError("A power contains a non-integer exponent") - repl = zip(pows, (sp.Mul(*[b]*e, evaluate=False) for b, e in (i.as_base_exp() for i in pows))) - return expr.subs(repl) + substitutions = zip(powers, (sp.Mul(*[b]*e, evaluate=False) for b, e in (i.as_base_exp() for i in powers))) + return expr.subs(substitutions) -def extractMostCommonFactor(term): +def extract_most_common_factor(term): """Processes a sum of fractions: determines the most common factor and splits term in common factor and rest""" - import operator - from collections import Counter - from sympy.functions import Abs - - coeffDict = term.as_coefficients_dict() - counter = Counter([Abs(v) for v in coeffDict.values()]) - commonFactor, occurrences = max(counter.items(), key=operator.itemgetter(1)) + coefficient_dict = term.as_coefficients_dict() + counter = Counter([Abs(v) for v in coefficient_dict.values()]) + common_factor, occurrences = max(counter.items(), key=operator.itemgetter(1)) if occurrences == 1 and (1 in counter): - commonFactor = 1 - return commonFactor, term / commonFactor + common_factor = 1 + return common_factor, term / common_factor -def mostCommonTermFactorization(term): - commonFactor, term = extractMostCommonFactor(term) - - factorization = sp.factor(term) - if factorization.is_Mul: - symbolsInFactorization = [] - constantsInFactorization = 1 - for arg in factorization.args: - if len(arg.atoms(sp.Symbol)) == 0: - constantsInFactorization *= arg - else: - symbolsInFactorization.append(arg) - if len(symbolsInFactorization) <= 1: - return sp.Mul(commonFactor, term, evaluate=False) - else: - args = symbolsInFactorization[:-1] + [constantsInFactorization * symbolsInFactorization[-1]] - return sp.Mul(commonFactor, *args) - else: - return sp.Mul(commonFactor, term, evaluate=False) +def count_operations(term: Union[sp.Expr, List[sp.Expr]], + only_type: Optional[str] = 'real') -> Dict[str, int]: + """Counts the number of additions, multiplications and division. + Args: + term: a sympy expression (term, assignment) or sequence of sympy objects + only_type: 'real' or 'int' to count only operations on these types, or None for all -def countNumberOfOperations(term, onlyType='real'): - """ - Counts the number of additions, multiplications and division - :param term: a sympy term, equation or sequence of terms/equations - :param onlyType: 'real' or 'int' to count only operations on these types, or None for all - :return: a dictionary with 'adds', 'muls' and 'divs' keys + Returns: + dict with 'adds', 'muls' and 'divs' keys """ result = {'adds': 0, 'muls': 0, 'divs': 0} if isinstance(term, Sequence): for element in term: - r = countNumberOfOperations(element, onlyType) + r = count_operations(element, only_type) for operationName in result.keys(): result[operationName] += r[operationName] return result @@ -437,27 +444,27 @@ def countNumberOfOperations(term, onlyType='real'): term = term.evalf() - def checkType(e): - if onlyType is None: + def check_type(e): + if only_type is None: return True try: - type = getBaseType(getTypeOfExpression(e)) + base_type = getBaseType(getTypeOfExpression(e)) except ValueError: return False - if onlyType == 'int' and (type.is_int() or type.is_uint()): + if only_type == 'int' and (base_type.is_int() or base_type.is_uint()): return True - if onlyType == 'real' and (type.is_float()): + if only_type == 'real' and (base_type.is_float()): return True else: - return type == onlyType + return base_type == only_type def visit(t): - visitChildren = True + visit_children = True if t.func is sp.Add: - if checkType(t): + if check_type(t): result['adds'] += len(t.args) - 1 elif t.func is sp.Mul: - if checkType(t): + if check_type(t): result['muls'] += len(t.args) - 1 for a in t.args: if a == 1 or a == -1: @@ -465,14 +472,14 @@ def countNumberOfOperations(term, onlyType='real'): elif t.func is sp.Float: pass elif isinstance(t, sp.Symbol): - visitChildren = False + visit_children = False elif isinstance(t, sp.tensor.Indexed): - visitChildren = False + visit_children = False elif t.is_integer: pass elif t.func is sp.Pow: - if checkType(t.args[0]): - visitChildren = False + if check_type(t.args[0]): + visit_children = False if t.exp.is_integer and t.exp.is_number: if t.exp >= 0: result['muls'] += int(t.exp) - 1 @@ -486,7 +493,7 @@ def countNumberOfOperations(term, onlyType='real'): else: warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") - if visitChildren: + if visit_children: for a in t.args: visit(a) @@ -494,14 +501,14 @@ def countNumberOfOperations(term, onlyType='real'): return result -def countNumberOfOperationsInAst(ast): - """Counts number of operations in an abstract syntax tree, see also :func:`countNumberOfOperations`""" +def count_operations_in_ast(ast) -> Dict[str, int]: + """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`""" from pystencils.astnodes import SympyAssignment result = {'adds': 0, 'muls': 0, 'divs': 0} def visit(node): if isinstance(node, SympyAssignment): - r = countNumberOfOperations(node.rhs) + r = count_operations(node.rhs) result['adds'] += r['adds'] result['muls'] += r['muls'] result['divs'] += r['divs'] @@ -512,37 +519,34 @@ def countNumberOfOperationsInAst(ast): return result -def matrixFromColumnVectors(columnVectors): - """Creates a sympy matrix from column vectors. - :param columnVectors: nested sequence - i.e. a sequence of column vectors - """ - c = columnVectors - return sp.Matrix([list(c[i]) for i in range(len(c))]).transpose() - - -def commonDenominator(expr): +def common_denominator(expr: sp.Expr) -> sp.Expr: + """Finds least common multiple of all denominators occurring in an expression""" denominators = [r.q for r in expr.atoms(sp.Rational)] return sp.lcm(denominators) -def getSymmetricPart(term, vars): +def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr: """ Returns the symmetric part of a sympy expressions. - :param term: sympy expression, labeled here as :math:`f` - :param vars: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...` - :returns: :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]` + Args: + expr: sympy expression, labeled here as :math:`f` + symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...` + + Returns: + :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]` """ - substitutionDict = {e: -e for e in vars} - return sp.Rational(1, 2) * (term + term.subs(substitutionDict)) + substitution_dict = {e: -e for e in symbols} + return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict)) -def sortEquationsTopologically(equationSequence): - res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in equationSequence]) +def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]: + """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" + res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in assignments]) return [Assignment(a, b) for a, b in res] -def getEquationsFromFunction(func, **kwargs): +def assignments_from_python_function(func, **kwargs): """ Mechanism to simplify the generation of a list of sympy equations. Introduces a special "assignment operator" written as "@=". Each line containing this operator gives an @@ -552,50 +556,51 @@ def getEquationsFromFunction(func, **kwargs): Example: - >>> def myKernel(): + >>> def my_kernel(s): ... from pystencils import Field ... f = Field.createGeneric('f', spatialDimensions=2, indexDimensions=0) ... g = f.newFieldWithDifferentName('g') ... - ... S.neighbors @= f[0,1] + f[1,0] - ... g[0,0] @= S.neighbors + f[0,0] - >>> getEquationsFromFunction(myKernel) + ... s.neighbors @= f[0,1] + f[1,0] + ... g[0,0] @= s.neighbors + f[0,0] + >>> assignments_from_python_function(my_kernel) [Assignment(neighbors, f_E + f_N), Assignment(g_C, f_C + neighbors)] """ import inspect import re - class SymbolCreator: - def __getattribute__(self, name): - return sp.Symbol(name) - - assignmentRegexp = re.compile(r'(\s*)(.+?)@=(.*)') - whitespaceRegexp = re.compile(r'(\s*)(.*)') - sourceLines = inspect.getsourcelines(func)[0] + assignment_regexp = re.compile(r'(\s*)(.+?)@=(.*)') + whitespace_regexp = re.compile(r'(\s*)(.*)') + source_lines = inspect.getsourcelines(func)[0] # determine indentation - firstCodeLine = sourceLines[1] - matchRes = whitespaceRegexp.match(firstCodeLine) - assert matchRes, "First line is not indented" - numWhitespaces = len(matchRes.group(1)) - - for i in range(1, len(sourceLines)): - sourceLine = sourceLines[i][numWhitespaces:] - if 'return' in sourceLine: + first_code_line = source_lines[1] + match_res = whitespace_regexp.match(first_code_line) + assert match_res, "First line is not indented" + num_whitespaces = len(match_res.group(1)) + + for i in range(1, len(source_lines)): + source_line = source_lines[i][num_whitespaces:] + if 'return' in source_line: raise ValueError("Function may not have a return statement!") - matchRes = assignmentRegexp.match(sourceLine) - if matchRes: - sourceLine = "%s_result.append(Eq(%s, %s))\n" % matchRes.groups() - sourceLines[i] = sourceLine + match_res = assignment_regexp.match(source_line) + if match_res: + source_line = "%s_result.append(Assignment(%s, %s))\n" % tuple(match_res.groups()[i] for i in range(3)) + source_lines[i] = source_line - code = "".join(sourceLines[1:]) + code = "".join(source_lines[1:]) result = [] - localsDict = {'_result': result, - 'Eq': Assignment, - 'S': SymbolCreator()} - localsDict.update(kwargs) - globalsDict = inspect.stack()[1][0].f_globals.copy() - globalsDict.update(inspect.stack()[1][0].f_locals) - - exec(code, globalsDict, localsDict) + locals_dict = {'_result': result, + 'Assignment': Assignment, + 's': SymbolCreator()} + locals_dict.update(kwargs) + globals_dict = inspect.stack()[1][0].f_globals.copy() + globals_dict.update(inspect.stack()[1][0].f_locals) + + exec(code, globals_dict, locals_dict) return result + + +class SymbolCreator: + def __getattribute__(self, name): + return sp.Symbol(name) diff --git a/transformations/transformations.py b/transformations/transformations.py index 5c9c9d1fe765939ddfb60dc98d77de0d7b20917b..ae5b08bd6ce59b2262f24cfbd334e14e8f7da9d5 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -21,22 +21,6 @@ def filteredTreeIteration(node, nodeType): yield from filteredTreeIteration(arg, nodeType) -def fastSubs(term, subsDict): - """Similar to sympy subs function. - This version is much faster for big substitution dictionaries than sympy version""" - if type(term) is sp.Matrix: - return term.copy().applyfunc(functools.partial(fastSubs, subsDict=subsDict)) - - def visit(expr): - if expr in subsDict: - return subsDict[expr] - if not hasattr(expr, 'args'): - return expr - paramList = [visit(a) for a in expr.args] - return expr if not paramList else expr.func(*paramList) - return visit(term) - - def getCommonShape(fieldSet): """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise ValueError is raised""" diff --git a/vectorization.py b/vectorization.py index 291f8a96af45d47ff7ff78f56b91241e496a9aab..8e1377c04a62c808cbfcef54bd7a7bd8cc8595af 100644 --- a/vectorization.py +++ b/vectorization.py @@ -1,7 +1,7 @@ import sympy as sp import warnings -from pystencils.sympyextensions import fastSubs +from pystencils.sympyextensions import fast_subs from pystencils.transformations import filteredTreeIteration from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \ PointerType @@ -97,7 +97,7 @@ def insertVectorCasts(astNode): substitutionDict = {} for asmt in filteredTreeIteration(astNode, ast.SympyAssignment): - subsExpr = fastSubs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) + subsExpr = fast_subs(asmt.rhs, substitutionDict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) asmt.rhs = visitExpr(subsExpr) rhsType = getTypeOfExpression(asmt.rhs) if isinstance(asmt.lhs, TypedSymbol):