diff --git a/simp/__init__.py b/simp/__init__.py index 0e8a85d1fac9e85830b9b26776fd929cc8937b8d..36efb062ded6ff27f2594193156762b3682e7932 100644 --- a/simp/__init__.py +++ b/simp/__init__.py @@ -2,9 +2,10 @@ from .assignment_collection import AssignmentCollection from .simplificationstrategy import SimplificationStrategy from .simplifications import sympy_cse, sympy_cse_on_assignment_list, \ apply_to_all_assignments, apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions, \ - subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions + subexpression_substitution_in_main_assignments, add_subexpressions_for_divisions, add_subexpressions_for_field_reads __all__ = ['AssignmentCollection', 'SimplificationStrategy', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', - 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions'] + 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', + 'add_subexpressions_for_field_reads'] diff --git a/simp/simplifications.py b/simp/simplifications.py index 076adf967ad412cb4b6a943c752c8ee2f46640ee..a42601b52a18f6fed181af95edfe498479b155d8 100644 --- a/simp/simplifications.py +++ b/simp/simplifications.py @@ -1,5 +1,7 @@ import sympy as sp from typing import Callable, List + +from pystencils import Field from pystencils.assignment import Assignment from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.sympyextensions import subs_additive @@ -83,6 +85,24 @@ def add_subexpressions_for_divisions(ac: AC) -> AC: return ac.new_with_substitutions(substitutions, True) +def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC: + r"""Substitutes field accesses on rhs of assignments with subexpressions + + Can change semantics of the update rule (which is the goal of this transformation) + This is useful if a field should be update in place - all values are loaded before into subexpression variables, + then the new values are computed and written to the same field in-place. + """ + field_reads = set() + if subexpressions: + for assignment in ac.subexpressions: + field_reads.update(assignment.rhs.atoms(Field.Access)) + if main_assignments: + for assignment in ac.main_assignments: + field_reads.update(assignment.rhs.atoms(Field.Access)) + substitutions = {fa: sp.Dummy() for fa in field_reads} + return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) + + def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: """Applies sympy expand operation to all equations in collection.""" def f(assignment_collection: AC) -> AC: