Commit eaeec78a by Martin Bauer

### RNG: Possibility to pass seed and block offset parameters

 from typing import Callable, List from itertools import chain from typing import Callable, List, Sequence, Union import sympy as sp from pystencils.assignment import Assignment from pystencils.astnodes import Node from pystencils.field import AbstractField, Field from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs from pystencils.sympyextensions import subs_additive AC = AssignmentCollection def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" edges = [] for c1, e1 in enumerate(assignments): if isinstance(e1, Assignment): symbols = [e1.lhs] elif isinstance(e1, Node): symbols = e1.symbols_defined else: symbols = [] for lhs in symbols: for c2, e2 in enumerate(assignments): if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: edges.append((c1, c2)) elif isinstance(e2, Node) and lhs in e2.undefined_symbols: edges.append((c1, c2)) return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))] def sympy_cse(ac: AC) -> AC: def sympy_cse(ac): """Searches for common subexpressions inside the equation collection. Searches is done in both the existing subexpressions as well as the assignments themselves. ... ... @@ -18,27 +36,28 @@ def sympy_cse(ac: AC) -> AC: with the additional subexpressions found """ symbol_gen = ac.subexpression_symbol_generator replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments, symbols=symbol_gen) all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)] other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)] replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen) replacement_eqs = [Assignment(*r) for r in replacements] modified_subexpressions = new_eq[:len(ac.subexpressions)] modified_update_equations = new_eq[len(ac.subexpressions):] new_subexpressions = replacement_eqs + modified_subexpressions topologically_sorted_pairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions]) new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs] new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions) return ac.copy(modified_update_equations, new_subexpressions) def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: """Extracts common subexpressions from a list of assignments.""" ec = AC([], assignments) from pystencils.simp.assignment_collection import AssignmentCollection ec = AssignmentCollection([], assignments) return sympy_cse(ec).all_assignments def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: def subexpression_substitution_in_existing_subexpressions(ac): """Goes through the subexpressions list and replaces the term in the following subexpressions.""" result = [] for outer_ctr, s in enumerate(ac.subexpressions): ... ... @@ -52,7 +71,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: return ac.copy(ac.main_assignments, result) def subexpression_substitution_in_main_assignments(ac: AC) -> AC: def subexpression_substitution_in_main_assignments(ac): """Replaces already existing subexpressions in the equations of the assignment_collection.""" result = [] for s in ac.main_assignments: ... ... @@ -63,7 +82,7 @@ def subexpression_substitution_in_main_assignments(ac: AC) -> AC: return ac.copy(result) def add_subexpressions_for_divisions(ac: AC) -> AC: def add_subexpressions_for_divisions(ac): r"""Introduces subexpressions for all divisions which have no constant in the denominator. For example :math:\frac{1}{x} is replaced while :math:\frac{1}{3} is not replaced. ... ... @@ -87,7 +106,7 @@ def add_subexpressions_for_divisions(ac: AC) -> AC: return ac.new_with_substitutions(substitutions, True) def add_subexpressions_for_sums(ac: AC) -> AC: def add_subexpressions_for_sums(ac): r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.""" addends = [] ... ... @@ -114,7 +133,7 @@ def add_subexpressions_for_sums(ac: AC) -> AC: return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC: def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True): r"""Substitutes field accesses on rhs of assignments with subexpressions Can change semantics of the update rule (which is the goal of this transformation) ... ... @@ -132,17 +151,36 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: def transform_rhs(assignment_list, transformation, *args, **kwargs): """Applies a transformation function on the rhs of each element of the passed assignment list If the list also contains other object, like AST nodes, these are ignored. Additional parameters are passed to the transformation function""" return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a for a in assignment_list] def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): return [Assignment(transformation(a.lhs, *args, **kwargs), transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a for a in assignment_list] def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]): """Applies sympy expand operation to all equations in collection.""" def f(ac: AC) -> AC: def f(ac): return ac.copy(transform_rhs(ac.main_assignments, operation)) f.__name__ = operation.__name__ return f def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]): """Applies the given operation on all subexpressions of the AC.""" def f(ac: AC) -> AC: def f(ac): return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation)) f.__name__ = operation.__name__ return f