Skip to content
Snippets Groups Projects
Commit e512822b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'fhennig/fix_canonicalize' into 'backend-rework'

Fix symbol canonicalization to not duplicate when marking as updated

See merge request !380
parents 3ee5d9b6 0845d667
1 merge request!380Fix symbol canonicalization to not duplicate when marking as updated
Pipeline #65467 passed with stages
in 1 minute
......@@ -33,7 +33,7 @@ class CanonContext:
return replacement
def mark_as_updated(self, symb: PsSymbol):
self.updated_symbols.add(self.deduplicate(symb))
self.updated_symbols.add(symb)
def is_live(self, symb: PsSymbol) -> bool:
return symb in self.live_symbols_map
......
......@@ -86,3 +86,31 @@ def test_do_not_constify():
assert ctx.find_symbol("x").dtype.const
assert not ctx.find_symbol("z").dtype.const
def test_loop_counters():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
f = Field.create_generic("f", 2, index_shape=(1,))
g = Field.create_generic("g", 2, index_shape=(1,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f)
ctx.set_iteration_space(ispace)
asm = Assignment(f.center(0), 2 * g.center(0))
body = PsBlock([factory.parse_sympy(asm)])
loops = factory.loops_from_ispace(ispace, body)
loops_copy = loops.clone()
ast = PsBlock([loops, loops_copy])
ast = canonicalize(ast)
assert loops_copy.counter.symbol.name == "ctr_0"
assert not loops_copy.counter.symbol.get_dtype().const
assert loops.counter.symbol.name == "ctr_0__0"
assert not loops.counter.symbol.get_dtype().const
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment