# simplifications.py (7.57 KB)
 Martin Bauer committed Aug 08, 2019 1 2 from itertools import chain from typing import Callable, List, Sequence, Union  Martin Bauer committed Nov 16, 2018 3   Martin Bauer committed Jul 11, 2019 4 5 import sympy as sp  Martin Bauer committed Apr 10, 2018 6 from pystencils.assignment import Assignment  Martin Bauer committed Aug 08, 2019 7 from pystencils.astnodes import Node  Martin Bauer committed Jul 11, 2019 8 from pystencils.field import AbstractField, Field  Martin Bauer committed Apr 10, 2018 9 from pystencils.sympyextensions import subs_additive  Martin Bauer committed Apr 10, 2018 10   Martin Bauer committed Aug 08, 2019 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28  def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" edges = [] for c1, e1 in enumerate(assignments): if isinstance(e1, Assignment): symbols = [e1.lhs] elif isinstance(e1, Node): symbols = e1.symbols_defined else: symbols = [] for lhs in symbols: for c2, e2 in enumerate(assignments): if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: edges.append((c1, c2)) elif isinstance(e2, Node) and lhs in e2.undefined_symbols: edges.append((c1, c2)) return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]  Martin Bauer committed Apr 10, 2018 29   Martin Bauer committed Apr 11, 2018 30   Martin Bauer committed Aug 08, 2019 31 def sympy_cse(ac):  Martin Bauer committed Apr 10, 2018 32 33 34 35  """Searches for common subexpressions inside the equation collection. Searches is done in both the existing subexpressions as well as the assignments themselves. It uses the sympy subexpression detection to do this. 
Return a new equation collection  Martin Bauer committed Apr 10, 2018 36 37  with the additional subexpressions found """  Martin Bauer committed Apr 10, 2018 38  symbol_gen = ac.subexpression_symbol_generator  Martin Bauer committed Aug 08, 2019 39 40 41 42 43  all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)] other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)] replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen)  Martin Bauer committed Apr 10, 2018 44  replacement_eqs = [Assignment(*r) for r in replacements]  Martin Bauer committed Apr 10, 2018 45   Martin Bauer committed Apr 10, 2018 46 47  modified_subexpressions = new_eq[:len(ac.subexpressions)] modified_update_equations = new_eq[len(ac.subexpressions):]  Martin Bauer committed Apr 10, 2018 48   Martin Bauer committed Aug 08, 2019 49  new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)  Martin Bauer committed Apr 10, 2018 50  return ac.copy(modified_update_equations, new_subexpressions)  Martin Bauer committed Apr 10, 2018 51 52   Martin Bauer committed Apr 10, 2018 53 54 def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: """Extracts common subexpressions from a list of assignments."""  Martin Bauer committed Aug 08, 2019 55 56  from pystencils.simp.assignment_collection import AssignmentCollection ec = AssignmentCollection([], assignments)  Martin Bauer committed Apr 10, 2018 57 58 59  return sympy_cse(ec).all_assignments  Martin Bauer committed Aug 08, 2019 60 def subexpression_substitution_in_existing_subexpressions(ac):  Martin Bauer committed Apr 10, 2018 61  """Goes through the subexpressions list and replaces the term in the following subexpressions."""  Martin Bauer committed Apr 10, 2018 62  result = []  Martin Bauer committed Apr 10, 2018 63  for outer_ctr, s in 
enumerate(ac.subexpressions):  Martin Bauer committed Apr 10, 2018 64  new_rhs = s.rhs  Martin Bauer committed Apr 10, 2018 65 66  for inner_ctr in range(outer_ctr): sub_expr = ac.subexpressions[inner_ctr]  Martin Bauer committed Apr 10, 2018 67 68 69  new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0) new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs) result.append(Assignment(s.lhs, new_rhs))  Martin Bauer committed Apr 10, 2018 70   Martin Bauer committed Apr 10, 2018 71  return ac.copy(ac.main_assignments, result)  Martin Bauer committed Apr 10, 2018 72 73   Martin Bauer committed Aug 08, 2019 74 def subexpression_substitution_in_main_assignments(ac):  Martin Bauer committed Apr 10, 2018 75  """Replaces already existing subexpressions in the equations of the assignment_collection."""  Martin Bauer committed Apr 10, 2018 76  result = []  Martin Bauer committed Apr 10, 2018 77 78  for s in ac.main_assignments: new_rhs = s.rhs  Martin Bauer committed Apr 10, 2018 79 80  for sub_expr in ac.subexpressions: new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)  Martin Bauer committed Apr 10, 2018 81 82  result.append(Assignment(s.lhs, new_rhs)) return ac.copy(result)  Martin Bauer committed Apr 10, 2018 83 84   Martin Bauer committed Aug 08, 2019 85 def add_subexpressions_for_divisions(ac):  Martin Bauer committed Apr 30, 2018 86  r"""Introduces subexpressions for all divisions which have no constant in the denominator.  Martin Bauer committed Apr 10, 2018 87   Martin Bauer committed Apr 30, 2018 88  For example :math:\frac{1}{x} is replaced while :math:\frac{1}{3} is not replaced.  
Martin Bauer committed Apr 10, 2018 89  """  Martin Bauer committed Apr 10, 2018 90 91  divisors = set()  Martin Bauer committed Apr 10, 2018 92  def search_divisors(term):  Martin Bauer committed Apr 10, 2018 93 94 95 96 97  if term.func == sp.Pow: if term.exp.is_integer and term.exp.is_number and term.exp < 0: divisors.add(term) else: for a in term.args:  Martin Bauer committed Apr 10, 2018 98  search_divisors(a)  Martin Bauer committed Apr 10, 2018 99   Martin Bauer committed Apr 10, 2018 100 101  for eq in ac.all_assignments: search_divisors(eq.rhs)  Martin Bauer committed Apr 10, 2018 102   Nils Kohl committed May 03, 2019 103  divisors = sorted(list(divisors), key=lambda x: str(x))  Martin Bauer committed Apr 10, 2018 104  new_symbol_gen = ac.subexpression_symbol_generator  Martin Bauer committed Apr 10, 2018 105  substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}  Martin Bauer committed Apr 10, 2018 106  return ac.new_with_substitutions(substitutions, True)  Martin Bauer committed Apr 11, 2018 107 108   Martin Bauer committed Aug 08, 2019 109 def add_subexpressions_for_sums(ac):  Nils Kohl committed Apr 26, 2019 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135  r"""Introduces subexpressions for all sums - i.e. 
splits addends into subexpressions.""" addends = [] def contains_sum(term): if term.func == sp.add.Add: return True if term.is_Atom: return False return any([contains_sum(a) for a in term.args]) def search_addends(term): if term.func == sp.add.Add: if all([not contains_sum(a) for a in term.args]): addends.extend(term.args) for a in term.args: search_addends(a) for eq in ac.all_assignments: search_addends(eq.rhs) addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)] new_symbol_gen = ac.subexpression_symbol_generator substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)} return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)  Martin Bauer committed Aug 08, 2019 136 def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):  Martin Bauer committed Nov 16, 2018 137 138 139 140 141 142 143 144 145 146 147 148 149  r"""Substitutes field accesses on rhs of assignments with subexpressions Can change semantics of the update rule (which is the goal of this transformation) This is useful if a field should be update in place - all values are loaded before into subexpression variables, then the new values are computed and written to the same field in-place. 
""" field_reads = set() if subexpressions: for assignment in ac.subexpressions: field_reads.update(assignment.rhs.atoms(Field.Access)) if main_assignments: for assignment in ac.main_assignments: field_reads.update(assignment.rhs.atoms(Field.Access))  Martin Bauer committed Aug 08, 2019 150  substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}  Martin Bauer committed Nov 16, 2018 151 152 153  return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)  Martin Bauer committed Aug 08, 2019 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 def transform_rhs(assignment_list, transformation, *args, **kwargs): """Applies a transformation function on the rhs of each element of the passed assignment list If the list also contains other object, like AST nodes, these are ignored. Additional parameters are passed to the transformation function""" return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a for a in assignment_list] def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): return [Assignment(transformation(a.lhs, *args, **kwargs), transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a for a in assignment_list] def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):  Martin Bauer committed Apr 11, 2018 170  """Applies sympy expand operation to all equations in collection."""  Martin Bauer committed Aug 08, 2019 171 172  def f(ac):  Martin Bauer committed Aug 08, 2019 173  return ac.copy(transform_rhs(ac.main_assignments, operation))  Martin Bauer committed Aug 08, 2019 174   Martin Bauer committed Apr 11, 2018 175 176 177 178  f.__name__ = operation.__name__ return f  Martin Bauer committed Aug 08, 2019 179 def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):  Martin Bauer committed Apr 11, 2018 180  """Applies the given operation on all 
subexpressions of the AC."""  Martin Bauer committed Aug 08, 2019 181 182  def f(ac):  Martin Bauer committed Aug 08, 2019 183  return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))  Martin Bauer committed Aug 08, 2019 184   Martin Bauer committed Apr 11, 2018 185  f.__name__ = operation.__name__  Martin Bauer committed Apr 13, 2018 186  return f