"""Simplification transformations operating on assignment collections and assignment lists."""
from itertools import chain
from typing import Callable, List, Sequence, Union

import sympy as sp

from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import AbstractField, Field
from pystencils.sympyextensions import subs_additive


def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs.

    Args:
        assignments: sequence of objects that either expose ``lhs``/``rhs`` attributes or are AST nodes

    Returns:
        new list containing the same objects, ordered such that definitions precede their uses

    Raises:
        NotImplementedError: if an element neither has ``lhs``/``rhs`` attributes nor is a ``Node``
    """
    # Symbols read by each element, looked up once instead of once per (defined symbol, element) pair.
    symbols_read = []
    for element in assignments:
        if isinstance(element, Assignment):
            symbols_read.append(element.rhs.free_symbols)
        elif isinstance(element, Node):
            symbols_read.append(element.undefined_symbols)
        else:
            # other objects contribute no incoming edges
            symbols_read.append(())

    edges = []
    for c1, e1 in enumerate(assignments):
        if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
            symbols = [e1.lhs]
        elif isinstance(e1, Node):
            symbols = e1.symbols_defined
        else:
            # str() is required here: concatenating a type object to a str raises TypeError
            raise NotImplementedError("Cannot sort topologically. Object of type " + str(type(e1))
                                      + " cannot be handled.")
        for lhs in symbols:
            for c2, read_by_e2 in enumerate(symbols_read):
                if lhs in read_by_e2:
                    # edge: element c1 defines a symbol that element c2 reads
                    edges.append((c1, c2))
    return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]


def sympy_cse(ac, **kwargs):
    """Searches for common subexpressions inside the assignment collection.

    The search is done in both the existing subexpressions and the main assignments.
    It uses the sympy common subexpression detection (``sp.cse``) to do this.

    Args:
        ac: assignment collection to process
        **kwargs: forwarded to ``sp.cse``

    Returns:
        a new assignment collection with the additional subexpressions found
    """
    symbol_gen = ac.subexpression_symbol_generator

    sub_assignments = [e for e in ac.subexpressions if isinstance(e, Assignment)]
    main_assignments = [e for e in ac.main_assignments if isinstance(e, Assignment)]
    all_assignments = sub_assignments + main_assignments
    other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
    replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen, **kwargs)

    replacement_eqs = [Assignment(*r) for r in replacements]

    # new_eq only contains the Assignment instances, so the split point is the number of
    # Assignment subexpressions - not len(ac.subexpressions), which also counts other objects.
    num_sub_assignments = len(sub_assignments)
    modified_subexpressions = new_eq[:num_sub_assignments]
    modified_update_equations = new_eq[num_sub_assignments:]

    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)


def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Extracts common subexpressions from a list of assignments.

    The list is wrapped into a temporary ``AssignmentCollection``, processed with :func:`sympy_cse`
    and the ``all_assignments`` of the result are returned.
    """
    # imported locally, as in the original code
    from pystencils.simp.assignment_collection import AssignmentCollection
    ec = AssignmentCollection([], assignments)
    return sympy_cse(ec).all_assignments


def subexpression_substitution_in_existing_subexpressions(ac):
    """Goes through the subexpressions list and replaces the term in the following subexpressions.

    For each subexpression, the right hand sides of all *preceding* subexpressions are searched in its
    right hand side and replaced by the corresponding left hand side symbol.

    Returns:
        new assignment collection with unchanged main assignments and the rewritten subexpressions
    """
    result = []
    for outer_ctr, s in enumerate(ac.subexpressions):
        new_rhs = s.rhs
        # only subexpressions defined before the current one are candidates
        for inner_ctr in range(outer_ctr):
            sub_expr = ac.subexpressions[inner_ctr]
            new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
            new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
        result.append(Assignment(s.lhs, new_rhs))

    return ac.copy(ac.main_assignments, result)


def subexpression_substitution_in_main_assignments(ac):
    """Replaces already existing subexpressions in the equations of the assignment_collection.

    Returns:
        new assignment collection whose main assignments have the rewritten right hand sides
    """
    result = []
    for s in ac.main_assignments:
        new_rhs = s.rhs
        for sub_expr in ac.subexpressions:
            new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        result.append(Assignment(s.lhs, new_rhs))
    return ac.copy(result)


def add_subexpressions_for_divisions(ac):
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
    """
    divisors = set()

    def search_divisors(term):
        if term.func == sp.Pow:
            # a power with a negative integer number as exponent is treated as a division
            if term.exp.is_integer and term.exp.is_number and term.exp < 0:
                divisors.add(term)
        else:
            for a in term.args:
                search_divisors(a)

    for eq in ac.all_assignments:
        search_divisors(eq.rhs)

    # sort by string representation to hand out the new symbols in a reproducible order
    divisors = sorted(divisors, key=str)
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)


def add_subexpressions_for_sums(ac):
    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.

    Only the addends of sums whose addends do not contain a further sum are collected. Addends that are
    plain symbols are left untouched, unless they are field accesses.
    """
    addends = []

    def contains_sum(term):
        # sp.Add is the public name of the class the original code accessed as sp.add.Add
        if term.func == sp.Add:
            return True
        if term.is_Atom:
            return False
        return any(contains_sum(a) for a in term.args)

    def search_addends(term):
        if term.func == sp.Add:
            if all(not contains_sum(a) for a in term.args):
                addends.extend(term.args)
        for a in term.args:
            search_addends(a)

    for eq in ac.all_assignments:
        search_addends(eq.rhs)

    addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)]
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
    return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)


def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
    r"""Substitutes field accesses on rhs of assignments with subexpressions.

    Can change semantics of the update rule (which is the goal of this transformation).
    This is useful if a field should be updated in place - all values are loaded before into subexpression
    variables, then the new values are computed and written to the same field in-place.

    Args:
        ac: assignment collection to process
        subexpressions: if True, field reads in the subexpressions are collected
        main_assignments: if True, field reads in the main assignments are collected
    """
    field_reads = set()
    to_iterate = []
    if subexpressions:
        to_iterate = chain(to_iterate, ac.subexpressions)
    if main_assignments:
        to_iterate = chain(to_iterate, ac.main_assignments)

    for assignment in to_iterate:
        # objects without lhs/rhs (e.g. other AST nodes) are skipped
        if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
            field_reads.update(assignment.rhs.atoms(Field.Access))

    # Iterate in sorted order: set iteration order is arbitrary, which would make the mapping
    # of field accesses to new symbols differ from run to run.
    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in sorted(field_reads, key=str)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
                                     substitute_on_lhs=False, sort_topologically=False)


def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function on the rhs of each element of the passed assignment list.

    If the list also contains other objects, like AST nodes, these are passed through unchanged.
    Additional parameters are passed to the transformation function.
    """
    return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
            for a in assignment_list]


def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function on both lhs and rhs of each element of the passed assignment list.

    Elements without ``lhs``/``rhs`` attributes are passed through unchanged.
    Additional parameters are passed to the transformation function.
    """
    return [Assignment(transformation(a.lhs, *args, **kwargs),
                       transformation(a.rhs, *args, **kwargs))
            if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
            for a in assignment_list]


def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
    """Creates a function that applies the given operation to the rhs of all main assignments of a collection.

    The returned function takes an assignment collection and returns a copy with transformed main assignments.
    It carries the ``__name__`` of the operation, if the operation has one.
    """

    def f(ac):
        return ac.copy(transform_rhs(ac.main_assignments, operation))

    # callables like functools.partial objects have no __name__ - keep the default name then
    f.__name__ = getattr(operation, '__name__', f.__name__)
    return f


def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
    """Creates a function that applies the given operation on the rhs of all subexpressions of the AC.

    The returned function takes an assignment collection and returns a copy with unchanged main assignments
    and transformed subexpressions. It carries the ``__name__`` of the operation, if the operation has one.
    """

    def f(ac):
        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))

    # callables like functools.partial objects have no __name__ - keep the default name then
    f.__name__ = getattr(operation, '__name__', f.__name__)
    return f