diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index b15b87a4fc948c1fd86add59c64d5b16dce89afa..da48324948b20d59e7cf8757c885396f0db68a8b 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -362,6 +362,29 @@ class AssignmentCollection: self.sub_expressions = [Assignment(k, v) for k, v in sub_expressions_dict.items()] + def find(self, *args, **kwargs): + return set.union(*[a.find(*args, **kwargs) for a in self.all_assignments]) + + def match(self, *args, **kwargs): + rtn = {} + for a in self.all_assignments: + partial_result = a.match(*args, **kwargs) + if partial_result: + rtn.update(partial_result) + return rtn + + def subs(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions] + ) + + def replace(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions] + ) + class SymbolGen: """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""