From 64d39897279003a2d47b8352b5e89cfad9c02592 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Mon, 29 Apr 2024 08:48:30 +0200 Subject: [PATCH] Do not hoist declarations of mutated variables --- .../hoist_loop_invariant_decls.py | 10 +++++++++- .../transformations/test_hoist_invariants.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index cb9c9e920..2368868a9 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -27,6 +27,7 @@ class HoistContext: def __init__(self) -> None: self.hoisted_nodes: list[PsDeclaration] = [] self.assigned_symbols: set[PsSymbol] = set() + self.mutated_symbols: set[PsSymbol] = set() self.invariant_symbols: set[PsSymbol] = set() def _is_invariant(self, expr: PsExpression) -> bool: @@ -123,6 +124,7 @@ class HoistLoopInvariantDeclarations: """Hoist invariant declarations out of the given loop.""" hc = HoistContext() hc.assigned_symbols.add(loop.counter.symbol) + hc.mutated_symbols.add(loop.counter.symbol) self._prepare_hoist(loop.body, hc) self._hoist_from_block(loop.body, hc) return hc @@ -134,8 +136,12 @@ class HoistLoopInvariantDeclarations: case PsExpression(): return + case PsDeclaration(PsSymbolExpr(lhs_symb), _): + hc.assigned_symbols.add(lhs_symb) + case PsAssignment(PsSymbolExpr(lhs_symb), _): hc.assigned_symbols.add(lhs_symb) + hc.mutated_symbols.add(lhs_symb) case PsAssignment(_, _): return @@ -147,6 +153,7 @@ class HoistLoopInvariantDeclarations: loop = stmt nested_hc = self._hoist(loop) hc.assigned_symbols |= nested_hc.assigned_symbols + hc.mutated_symbols |= nested_hc.mutated_symbols statements_new += nested_hc.hoisted_nodes if loop.body.statements: statements_new.append(loop) @@ -169,7 +176,8 @@ class HoistLoopInvariantDeclarations: for node in block.statements: if isinstance(node, PsDeclaration): - if hc._is_invariant(node.rhs): + lhs_symb = cast(PsSymbolExpr, node.lhs).symbol + if lhs_symb not in hc.mutated_symbols and hc._is_invariant(node.rhs): hc.hoisted_nodes.append(node) hc.invariant_symbols.add(node.declared_symbol) else: diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py index db78efce5..15514f1da 100644 --- a/tests/nbackend/transformations/test_hoist_invariants.py +++ b/tests/nbackend/transformations/test_hoist_invariants.py @@ -193,3 +193,22 @@ def test_hoisting_eliminates_loops(): assert isinstance(ast, PsBlock) # All statements are hoisted and the loops are removed assert ast.statements == invariant_decls + + +def test_hoist_mutation(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x = sp.Symbol("x") + x_decl = factory.parse_sympy(Assignment(x, 1)) + x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1)) + + inner_loop = factory.loop("j", slice(10), PsBlock([x_update])) + outer_loop = factory.loop("i", slice(10), PsBlock([x_decl, inner_loop])) + + result = hoist(outer_loop) + + # x is updated in the loop, so nothing can be hoisted + assert isinstance(result, PsLoop) + assert result.body.statements == [x_decl, inner_loop] -- GitLab