From 3715b05090da53741c8d52ce519542c4d61502f9 Mon Sep 17 00:00:00 2001 From: Alexander Reinauer <areinauer@icp.uni-stuttgart.de> Date: Sun, 21 Apr 2024 14:23:52 +0200 Subject: [PATCH] Fix new_merged for AssignmentCollections --- src/pystencils/simp/assignment_collection.py | 5 +++-- tests/test_assignment_collection.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/pystencils/simp/assignment_collection.py b/src/pystencils/simp/assignment_collection.py index b0c09cec9..40a5e4cc8 100644 --- a/src/pystencils/simp/assignment_collection.py +++ b/src/pystencils/simp/assignment_collection.py @@ -286,12 +286,13 @@ class AssignmentCollection: processed_other_subexpression_equations = [] for other_subexpression_eq in other.subexpressions: if other_subexpression_eq.lhs in own_subexpression_symbols: - if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: + new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict) + if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: continue # exact the same subexpression equation exists already else: # different definition - a new name has to be introduced new_lhs = next(self.subexpression_symbol_generator) - new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) + new_eq = Assignment(new_lhs, new_rhs) processed_other_subexpression_equations.append(new_eq) substitution_dict[other_subexpression_eq.lhs] = new_lhs else: diff --git a/tests/test_assignment_collection.py b/tests/test_assignment_collection.py index f0c1f2a91..7260146f3 100644 --- a/tests/test_assignment_collection.py +++ b/tests/test_assignment_collection.py @@ -170,3 +170,19 @@ def test_new_merged(): assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments assert a1 in merged_ac.subexpressions assert a3 in merged_ac.subexpressions + + a1 = ps.Assignment(a, 20) + a2 = ps.Assignment(a, 10) + acommon = ps.Assignment(b, a) + + # main assignments + a3 = ps.Assignment(f[0, 0](0), b) + a4 = ps.Assignment(d[0, 0](0), b) + + ac = ps.AssignmentCollection([a3], subexpressions=[a1, acommon]) + ac2 = ps.AssignmentCollection([a4], subexpressions=[a2, acommon]) + + merged_ac = ac.new_merged(ac2).new_without_subexpressions() + + assert ps.Assignment(f[0, 0](0), 20) in merged_ac.main_assignments + assert ps.Assignment(d[0, 0](0), 10) in merged_ac.main_assignments -- GitLab