From 64d39897279003a2d47b8352b5e89cfad9c02592 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Mon, 29 Apr 2024 08:48:30 +0200
Subject: [PATCH] Do not hoist declarations of mutated variables

---
 .../hoist_loop_invariant_decls.py             | 10 +++++++++-
 .../transformations/test_hoist_invariants.py  | 19 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index cb9c9e920..2368868a9 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -27,6 +27,7 @@ class HoistContext:
     def __init__(self) -> None:
         self.hoisted_nodes: list[PsDeclaration] = []
         self.assigned_symbols: set[PsSymbol] = set()
+        self.mutated_symbols: set[PsSymbol] = set()
         self.invariant_symbols: set[PsSymbol] = set()
 
     def _is_invariant(self, expr: PsExpression) -> bool:
@@ -123,6 +124,7 @@ class HoistLoopInvariantDeclarations:
         """Hoist invariant declarations out of the given loop."""
         hc = HoistContext()
         hc.assigned_symbols.add(loop.counter.symbol)
+        hc.mutated_symbols.add(loop.counter.symbol)
         self._prepare_hoist(loop.body, hc)
         self._hoist_from_block(loop.body, hc)
         return hc
@@ -134,8 +136,12 @@ class HoistLoopInvariantDeclarations:
             case PsExpression():
                 return
 
+            case PsDeclaration(PsSymbolExpr(lhs_symb), _):
+                hc.assigned_symbols.add(lhs_symb)
+
             case PsAssignment(PsSymbolExpr(lhs_symb), _):
                 hc.assigned_symbols.add(lhs_symb)
+                hc.mutated_symbols.add(lhs_symb)
 
             case PsAssignment(_, _):
                 return
@@ -147,6 +153,7 @@ class HoistLoopInvariantDeclarations:
                         loop = stmt
                         nested_hc = self._hoist(loop)
                         hc.assigned_symbols |= nested_hc.assigned_symbols
+                        hc.mutated_symbols |= nested_hc.mutated_symbols
                         statements_new += nested_hc.hoisted_nodes
                         if loop.body.statements:
                             statements_new.append(loop)
@@ -169,7 +176,8 @@ class HoistLoopInvariantDeclarations:
 
         for node in block.statements:
             if isinstance(node, PsDeclaration):
-                if hc._is_invariant(node.rhs):
+                lhs_symb = cast(PsSymbolExpr, node.lhs).symbol
+                if lhs_symb not in hc.mutated_symbols and hc._is_invariant(node.rhs):
                     hc.hoisted_nodes.append(node)
                     hc.invariant_symbols.add(node.declared_symbol)
                 else:
diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py
index db78efce5..15514f1da 100644
--- a/tests/nbackend/transformations/test_hoist_invariants.py
+++ b/tests/nbackend/transformations/test_hoist_invariants.py
@@ -193,3 +193,22 @@ def test_hoisting_eliminates_loops():
     assert isinstance(ast, PsBlock)
     #   All statements are hoisted and the loops are removed
     assert ast.statements == invariant_decls
+
+
+def test_hoist_mutation():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    hoist = HoistLoopInvariantDeclarations(ctx)
+
+    x = sp.Symbol("x")
+    x_decl = factory.parse_sympy(Assignment(x, 1))
+    x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1))
+
+    inner_loop = factory.loop("j", slice(10), PsBlock([x_update]))
+    outer_loop = factory.loop("i", slice(10), PsBlock([x_decl, inner_loop]))
+
+    result = hoist(outer_loop)
+
+    #   x is updated in the loop, so nothing can be hoisted
+    assert isinstance(result, PsLoop)
+    assert result.body.statements == [x_decl, inner_loop]
-- 
GitLab