diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index e0f5ec926376205f7f5ed68650791e75b1b634da..dcd89004fc03eb8d0564f80906380c018bda8820 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -197,14 +197,16 @@ class AssignmentCollection: return res def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, - substitute_on_lhs: bool = True) -> 'AssignmentCollection': + substitute_on_lhs: bool = True, + sort_topologically: bool = True) -> 'AssignmentCollection': """Returns new object, where terms are substituted according to the passed substitution dict. Args: substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments - + sort_topologically: if subexpressions are added as substitutions and this parameters is true, + the subexpressions are sorted topologically after insertion Returns: New AssignmentCollection where substitutions have been applied, self is not altered. """ @@ -215,7 +217,8 @@ class AssignmentCollection: if add_substitutions_as_subexpressions: transformed_subexpressions = [Assignment(b, a) for a, b in substitutions.items()] + transformed_subexpressions - transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) + if sort_topologically: + transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) return self.copy(transformed_assignments, transformed_subexpressions) def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 3a4c64764e3c17d46c204e6a39abeab5b3d13439..3d2f57ce3c298fd7a20dc3bcf3535ed76ef43fd9 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -104,7 +104,7 @@ def add_subexpressions_for_divisions(ac): divisors = sorted(list(divisors), key=lambda x: str(x)) new_symbol_gen = ac.subexpression_symbol_generator substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)} - return ac.new_with_substitutions(substitutions, True) + return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) def add_subexpressions_for_sums(ac): @@ -142,14 +142,18 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments then the new values are computed and written to the same field in-place. """ field_reads = set() + to_iterate = [] if subexpressions: - for assignment in ac.subexpressions: - field_reads.update(assignment.rhs.atoms(Field.Access)) + to_iterate = chain(to_iterate, ac.subexpressions) if main_assignments: - for assignment in ac.main_assignments: + to_iterate = chain(to_iterate, ac.main_assignments) + + for assignment in to_iterate: + if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'): field_reads.update(assignment.rhs.atoms(Field.Access)) substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads} - return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) + return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, + substitute_on_lhs=False, sort_topologically=False) def transform_rhs(assignment_list, transformation, *args, **kwargs):