Skip to content
Snippets Groups Projects
Commit d0654625 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'bauerd/fix-hoisting' into 'backend-rework'

Do not hoist declarations of mutated variables

See merge request !381
parents 4b2de595 64d39897
1 merge request!381Do not hoist declarations of mutated variables
Pipeline #65504 passed with stages
in 4 minutes and 41 seconds
......@@ -27,6 +27,7 @@ class HoistContext:
def __init__(self) -> None:
self.hoisted_nodes: list[PsDeclaration] = []
self.assigned_symbols: set[PsSymbol] = set()
self.mutated_symbols: set[PsSymbol] = set()
self.invariant_symbols: set[PsSymbol] = set()
def _is_invariant(self, expr: PsExpression) -> bool:
......@@ -123,6 +124,7 @@ class HoistLoopInvariantDeclarations:
"""Hoist invariant declarations out of the given loop."""
hc = HoistContext()
hc.assigned_symbols.add(loop.counter.symbol)
hc.mutated_symbols.add(loop.counter.symbol)
self._prepare_hoist(loop.body, hc)
self._hoist_from_block(loop.body, hc)
return hc
......@@ -134,8 +136,12 @@ class HoistLoopInvariantDeclarations:
case PsExpression():
return
case PsDeclaration(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
case PsAssignment(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
hc.mutated_symbols.add(lhs_symb)
case PsAssignment(_, _):
return
......@@ -147,6 +153,7 @@ class HoistLoopInvariantDeclarations:
loop = stmt
nested_hc = self._hoist(loop)
hc.assigned_symbols |= nested_hc.assigned_symbols
hc.mutated_symbols |= nested_hc.mutated_symbols
statements_new += nested_hc.hoisted_nodes
if loop.body.statements:
statements_new.append(loop)
......@@ -169,7 +176,8 @@ class HoistLoopInvariantDeclarations:
for node in block.statements:
if isinstance(node, PsDeclaration):
if hc._is_invariant(node.rhs):
lhs_symb = cast(PsSymbolExpr, node.lhs).symbol
if lhs_symb not in hc.mutated_symbols and hc._is_invariant(node.rhs):
hc.hoisted_nodes.append(node)
hc.invariant_symbols.add(node.declared_symbol)
else:
......
......@@ -193,3 +193,22 @@ def test_hoisting_eliminates_loops():
assert isinstance(ast, PsBlock)
# All statements are hoisted and the loops are removed
assert ast.statements == invariant_decls
def test_hoist_mutation():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
x = sp.Symbol("x")
x_decl = factory.parse_sympy(Assignment(x, 1))
x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1))
inner_loop = factory.loop("j", slice(10), PsBlock([x_update]))
outer_loop = factory.loop("i", slice(10), PsBlock([x_decl, inner_loop]))
result = hoist(outer_loop)
# x is updated in the loop, so nothing can be hoisted
assert isinstance(result, PsLoop)
assert result.body.statements == [x_decl, inner_loop]
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment