From 0fd71fbf9ceaf303fb16a54edfe7c7d2169afac5 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 9 Aug 2019 09:49:13 +0200 Subject: [PATCH] Fix bugs recently introduced in topological sort generalizations --- pystencils/simp/assignment_collection.py | 6 +++--- pystencils/simp/simplifications.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 9d253ff7b..e0f5ec926 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -5,7 +5,7 @@ import sympy as sp from pystencils.assignment import Assignment from pystencils.simp.simplifications import ( - sort_assignments_topologically, sympy_cse_on_assignment_list, + sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) from pystencils.sympyextensions import count_operations, fast_subs @@ -85,9 +85,9 @@ class AssignmentCollection: def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" if sort_subexpressions: - self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions) + self.subexpressions = sort_assignments_topologically(self.subexpressions) if sort_main_assignments: - self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments) + self.main_assignments = sort_assignments_topologically(self.main_assignments) # ---------------------------------------------- Properties ------------------------------------------------------- diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index ab2b3d83d..3a4c64764 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -13,12 +13,13 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" edges = [] for c1, e1 in enumerate(assignments): - if isinstance(e1, Assignment): + if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'): symbols = [e1.lhs] elif isinstance(e1, Node): symbols = e1.symbols_defined else: - symbols = [] + raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") + for lhs in symbols: for c2, e2 in enumerate(assignments): if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: @@ -155,14 +156,14 @@ def transform_rhs(assignment_list, transformation, *args, **kwargs): """Applies a transformation function on the rhs of each element of the passed assignment list If the list also contains other object, like AST nodes, these are ignored. Additional parameters are passed to the transformation function""" - return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a + return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a for a in assignment_list] def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): return [Assignment(transformation(a.lhs, *args, **kwargs), transformation(a.rhs, *args, **kwargs)) - if isinstance(a, Assignment) else a + if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a for a in assignment_list] -- GitLab