From 63a9140d417c022784b54289fa7a84a7cbcf63fc Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 12 Aug 2019 13:59:17 +0200
Subject: [PATCH] Fixes in assignments collection simplifications / topological
 sort

---
 pystencils/simp/assignment_collection.py |  9 ++++++---
 pystencils/simp/simplifications.py       | 14 +++++++++-----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index e0f5ec926..dcd89004f 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -197,14 +197,16 @@ class AssignmentCollection:
         return res
 
     def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
-                               substitute_on_lhs: bool = True) -> 'AssignmentCollection':
+                               substitute_on_lhs: bool = True,
+                               sort_topologically: bool = True) -> 'AssignmentCollection':
         """Returns new object, where terms are substituted according to the passed substitution dict.
 
         Args:
             substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
             add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
             substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
-
+            sort_topologically: if subexpressions are added as substitutions and this parameters is true,
+                                the subexpressions are sorted topologically after insertion
         Returns:
             New AssignmentCollection where substitutions have been applied, self is not altered.
         """
@@ -215,7 +217,8 @@ class AssignmentCollection:
         if add_substitutions_as_subexpressions:
             transformed_subexpressions = [Assignment(b, a) for a, b in
                                           substitutions.items()] + transformed_subexpressions
-            transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
+            if sort_topologically:
+                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
         return self.copy(transformed_assignments, transformed_subexpressions)
 
     def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 3a4c64764..3d2f57ce3 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -104,7 +104,7 @@ def add_subexpressions_for_divisions(ac):
     divisors = sorted(list(divisors), key=lambda x: str(x))
     new_symbol_gen = ac.subexpression_symbol_generator
     substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
-    return ac.new_with_substitutions(substitutions, True)
+    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
 
 
 def add_subexpressions_for_sums(ac):
@@ -142,14 +142,18 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
     then the new values are computed and written to the same field in-place.
     """
     field_reads = set()
+    to_iterate = []
     if subexpressions:
-        for assignment in ac.subexpressions:
-            field_reads.update(assignment.rhs.atoms(Field.Access))
+        to_iterate = chain(to_iterate, ac.subexpressions)
     if main_assignments:
-        for assignment in ac.main_assignments:
+        to_iterate = chain(to_iterate, ac.main_assignments)
+
+    for assignment in to_iterate:
+        if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
             field_reads.update(assignment.rhs.atoms(Field.Access))
     substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
-    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
+    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
+                                     substitute_on_lhs=False, sort_topologically=False)
 
 
 def transform_rhs(assignment_list, transformation, *args, **kwargs):
-- 
GitLab