Commit 63a9140d authored by Martin Bauer's avatar Martin Bauer

Fixes in assignments collection simplifications / topological sort

parent 0fd71fbf
......@@ -197,14 +197,16 @@ class AssignmentCollection:
return res
def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
substitute_on_lhs: bool = True) -> 'AssignmentCollection':
substitute_on_lhs: bool = True,
sort_topologically: bool = True) -> 'AssignmentCollection':
"""Returns new object, where terms are substituted according to the passed substitution dict.
Args:
substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
sort_topologically: if subexpressions are added as substitutions and this parameters is true,
the subexpressions are sorted topologically after insertion
Returns:
New AssignmentCollection where substitutions have been applied, self is not altered.
"""
......@@ -215,7 +217,8 @@ class AssignmentCollection:
if add_substitutions_as_subexpressions:
transformed_subexpressions = [Assignment(b, a) for a, b in
substitutions.items()] + transformed_subexpressions
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
if sort_topologically:
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
return self.copy(transformed_assignments, transformed_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
......
......@@ -104,7 +104,7 @@ def add_subexpressions_for_divisions(ac):
divisors = sorted(list(divisors), key=lambda x: str(x))
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, True)
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def add_subexpressions_for_sums(ac):
......@@ -142,14 +142,18 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
then the new values are computed and written to the same field in-place.
"""
field_reads = set()
to_iterate = []
if subexpressions:
for assignment in ac.subexpressions:
field_reads.update(assignment.rhs.atoms(Field.Access))
to_iterate = chain(to_iterate, ac.subexpressions)
if main_assignments:
for assignment in ac.main_assignments:
to_iterate = chain(to_iterate, ac.main_assignments)
for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False)
def transform_rhs(assignment_list, transformation, *args, **kwargs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment