Skip to content
Snippets Groups Projects
Commit 0845d667 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix symbol canonicalization to not duplicate when marking as updated

parent 3ee5d9b6
Branches
Tags
1 merge request!380Fix symbol canonicalization to not duplicate when marking as updated
Pipeline #65464 passed with stages
in 5 minutes and 10 seconds
......@@ -33,7 +33,7 @@ class CanonContext:
return replacement
def mark_as_updated(self, symb: PsSymbol):
self.updated_symbols.add(self.deduplicate(symb))
self.updated_symbols.add(symb)
def is_live(self, symb: PsSymbol) -> bool:
return symb in self.live_symbols_map
......
......@@ -86,3 +86,31 @@ def test_do_not_constify():
assert ctx.find_symbol("x").dtype.const
assert not ctx.find_symbol("z").dtype.const
def test_loop_counters():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
f = Field.create_generic("f", 2, index_shape=(1,))
g = Field.create_generic("g", 2, index_shape=(1,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f)
ctx.set_iteration_space(ispace)
asm = Assignment(f.center(0), 2 * g.center(0))
body = PsBlock([factory.parse_sympy(asm)])
loops = factory.loops_from_ispace(ispace, body)
loops_copy = loops.clone()
ast = PsBlock([loops, loops_copy])
ast = canonicalize(ast)
assert loops_copy.counter.symbol.name == "ctr_0"
assert not loops_copy.counter.symbol.get_dtype().const
assert loops.counter.symbol.name == "ctr_0__0"
assert not loops.counter.symbol.get_dtype().const
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment