From d7ad30ffa27fc02b3196a24e3c90b6d5d8a0f933 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Thu, 1 Aug 2019 16:54:45 +0200 Subject: [PATCH] Started with Support for AST nodes in assignment collection - allow AST nodes in assignment collection to e.g. put RNG nodes in LB method --- pystencils/simp/assignment_collection.py | 53 +++++++++++++++++++----- pystencils/simp/simplifications.py | 12 +++--- pystencils/sympyextensions.py | 6 --- 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index c5a1837..fca052e 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -4,7 +4,23 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, import sympy as sp from pystencils.assignment import Assignment -from pystencils.sympyextensions import count_operations, fast_subs, sort_assignments_topologically +from pystencils.astnodes import Node +from pystencils.sympyextensions import count_operations, fast_subs + + +def transform_rhs(assignment_list, transformation, *args, **kwargs): + """Applies a transformation function on the rhs of each element of the passed assignment list + If the list also contains other object, like AST nodes, these are ignored. + Additional parameters are passed to the transformation function""" + return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a + for a in assignment_list] + + +def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): + return [Assignment(transformation(a.lhs, *args, **kwargs), + transformation(a.rhs, *args, **kwargs)) + if isinstance(a, Assignment) else a + for a in assignment_list] class AssignmentCollection: @@ -205,17 +221,15 @@ class AssignmentCollection: Returns: New AssignmentCollection where substitutions have been applied, self is not altered. """ - if substitute_on_lhs: - new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions] - new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments] - else: - new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions] - new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments] + transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs + transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions) + transformed_assignments = transform(self.main_assignments, fast_subs, substitutions) if add_substitutions_as_subexpressions: - new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions - new_subexpressions = sort_assignments_topologically(new_subexpressions) - return self.copy(new_equations, new_subexpressions) + transformed_subexpressions = [Assignment(b, a) for a, b in + substitutions.items()] + transformed_subexpressions + transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) + return self.copy(transformed_assignments, transformed_subexpressions) def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" @@ -405,3 +419,22 @@ class SymbolGen: name = "{}_{}".format(self._symbol, self._ctr) self._ctr += 1 return sp.Symbol(name) + + +def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: + """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" + edges = [] + for c1, e1 in enumerate(assignments): + if isinstance(e1, Assignment): + symbols = [e1.lhs] + elif isinstance(e1, Node): + symbols = e1.symbols_defined + else: + symbols = [] + for lhs in symbols: + for c2, e2 in enumerate(assignments): + if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: + edges.append((c1, c2)) + elif isinstance(e2, Node) and lhs in e2.undefined_symbols: + edges.append((c1, c2)) + return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))] diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 22ff1be..24726fd 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -4,7 +4,7 @@ import sympy as sp from pystencils.assignment import Assignment from pystencils.field import AbstractField, Field -from pystencils.simp.assignment_collection import AssignmentCollection +from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs from pystencils.sympyextensions import subs_additive AC = AssignmentCollection @@ -128,15 +128,14 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm if main_assignments: for assignment in ac.main_assignments: field_reads.update(assignment.rhs.atoms(Field.Access)) - substitutions = {fa: sp.Dummy() for fa in field_reads} + substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads} return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: """Applies sympy expand operation to all equations in collection.""" - def f(assignment_collection: AC) -> AC: - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] - return assignment_collection.copy(result) + def f(ac: AC) -> AC: + return ac.copy(transform_rhs(ac.main_assignments, operation)) f.__name__ = operation.__name__ return f @@ -144,7 +143,6 @@ def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callabl def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: """Applies the given operation on all subexpressions of the AC.""" def f(ac: AC) -> AC: - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] - return ac.copy(ac.main_assignments, result) + return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation)) f.__name__ = operation.__name__ return f diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 2741f7c..afdf0fd 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -573,12 +573,6 @@ def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr: return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict)) -def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]: - """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" - res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in assignments]) - return [Assignment(a, b) for a, b in res] - - class SymbolCreator: def __getattribute__(self, name): return sp.Symbol(name) -- GitLab