Commit 0fd71fbf authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix bugs recently introduced in topological sort generalizations

parent bd49f37e
......@@ -5,7 +5,7 @@ import sympy as sp
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (
sort_assignments_topologically, sympy_cse_on_assignment_list,
transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs
......@@ -85,9 +85,9 @@ class AssignmentCollection:
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
"""Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
if sort_subexpressions:
self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions)
self.subexpressions = sort_assignments_topologically(self.subexpressions)
if sort_main_assignments:
self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments)
self.main_assignments = sort_assignments_topologically(self.main_assignments)
# ---------------------------------------------- Properties -------------------------------------------------------
......@@ -13,12 +13,13 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if isinstance(e1, Assignment):
if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
symbols = []
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.")
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
......@@ -155,14 +156,14 @@ def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if isinstance(a, Assignment) else a
if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment