From d7ad30ffa27fc02b3196a24e3c90b6d5d8a0f933 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 1 Aug 2019 16:54:45 +0200
Subject: [PATCH] Started with Support for AST nodes in assignment collection

- allow AST nodes in assignment collection to e.g. put RNG nodes in
  LB method
---
 pystencils/simp/assignment_collection.py | 53 +++++++++++++++++++-----
 pystencils/simp/simplifications.py       | 12 +++---
 pystencils/sympyextensions.py            |  6 ---
 3 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index c5a1837..fca052e 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -4,7 +4,23 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
 import sympy as sp
 
 from pystencils.assignment import Assignment
-from pystencils.sympyextensions import count_operations, fast_subs, sort_assignments_topologically
+from pystencils.astnodes import Node
+from pystencils.sympyextensions import count_operations, fast_subs
+
+
+def transform_rhs(assignment_list, transformation, *args, **kwargs):
+    """Applies a transformation function on the rhs of each element of the passed assignment list
+    If the list also contains other object, like AST nodes, these are ignored.
+    Additional parameters are passed to the transformation function"""
+    return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
+            for a in assignment_list]
+
+
+def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
+    return [Assignment(transformation(a.lhs, *args, **kwargs),
+                       transformation(a.rhs, *args, **kwargs))
+            if isinstance(a, Assignment) else a
+            for a in assignment_list]
 
 
 class AssignmentCollection:
@@ -205,17 +221,15 @@ class AssignmentCollection:
         Returns:
             New AssignmentCollection where substitutions have been applied, self is not altered.
         """
-        if substitute_on_lhs:
-            new_subexpressions = [fast_subs(eq, substitutions) for eq in self.subexpressions]
-            new_equations = [fast_subs(eq, substitutions) for eq in self.main_assignments]
-        else:
-            new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.subexpressions]
-            new_equations = [Assignment(eq.lhs, fast_subs(eq.rhs, substitutions)) for eq in self.main_assignments]
+        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
+        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
+        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
 
         if add_substitutions_as_subexpressions:
-            new_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + new_subexpressions
-            new_subexpressions = sort_assignments_topologically(new_subexpressions)
-        return self.copy(new_equations, new_subexpressions)
+            transformed_subexpressions = [Assignment(b, a) for a, b in
+                                          substitutions.items()] + transformed_subexpressions
+            transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
+        return self.copy(transformed_assignments, transformed_subexpressions)
 
     def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
         """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
@@ -405,3 +419,22 @@ class SymbolGen:
         name = "{}_{}".format(self._symbol, self._ctr)
         self._ctr += 1
         return sp.Symbol(name)
+
+
+def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
+    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
+    edges = []
+    for c1, e1 in enumerate(assignments):
+        if isinstance(e1, Assignment):
+            symbols = [e1.lhs]
+        elif isinstance(e1, Node):
+            symbols = e1.symbols_defined
+        else:
+            symbols = []
+        for lhs in symbols:
+            for c2, e2 in enumerate(assignments):
+                if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
+                    edges.append((c1, c2))
+                elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
+                    edges.append((c1, c2))
+    return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 22ff1be..24726fd 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -4,7 +4,7 @@ import sympy as sp
 
 from pystencils.assignment import Assignment
 from pystencils.field import AbstractField, Field
-from pystencils.simp.assignment_collection import AssignmentCollection
+from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs
 from pystencils.sympyextensions import subs_additive
 
 AC = AssignmentCollection
@@ -128,15 +128,14 @@ def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignm
     if main_assignments:
         for assignment in ac.main_assignments:
             field_reads.update(assignment.rhs.atoms(Field.Access))
-    substitutions = {fa: sp.Dummy() for fa in field_reads}
+    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
     return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
 
 
 def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
     """Applies sympy expand operation to all equations in collection."""
-    def f(assignment_collection: AC) -> AC:
-        result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
-        return assignment_collection.copy(result)
+    def f(ac: AC) -> AC:
+        return ac.copy(transform_rhs(ac.main_assignments, operation))
     f.__name__ = operation.__name__
     return f
 
@@ -144,7 +143,6 @@ def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callabl
 def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
     """Applies the given operation on all subexpressions of the AC."""
     def f(ac: AC) -> AC:
-        result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
-        return ac.copy(ac.main_assignments, result)
+        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
     f.__name__ = operation.__name__
     return f
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 2741f7c..afdf0fd 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -573,12 +573,6 @@ def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
     return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
 
 
-def sort_assignments_topologically(assignments: Sequence[Assignment]) -> List[Assignment]:
-    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
-    res = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in assignments])
-    return [Assignment(a, b) for a, b in res]
-
-
 class SymbolCreator:
     def __getattribute__(self, name):
         return sp.Symbol(name)
-- 
GitLab