From 0845d667afdb41e64ff7a0aadd07a19124276f51 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 26 Apr 2024 17:13:09 +0200 Subject: [PATCH] Fix symbol canonicalization to not duplicate when marking as updated --- .../transformations/canonicalize_symbols.py | 2 +- .../test_canonicalize_symbols.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index 3900105b8..e55807ef4 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -33,7 +33,7 @@ class CanonContext: return replacement def mark_as_updated(self, symb: PsSymbol): - self.updated_symbols.add(self.deduplicate(symb)) + self.updated_symbols.add(symb) def is_live(self, symb: PsSymbol) -> bool: return symb in self.live_symbols_map diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py index 43f269163..6a4785564 100644 --- a/tests/nbackend/transformations/test_canonicalize_symbols.py +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -86,3 +86,31 @@ def test_do_not_constify(): assert ctx.find_symbol("x").dtype.const assert not ctx.find_symbol("z").dtype.const + + +def test_loop_counters(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + + f = Field.create_generic("f", 2, index_shape=(1,)) + g = Field.create_generic("g", 2, index_shape=(1,)) + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f) + ctx.set_iteration_space(ispace) + + asm = Assignment(f.center(0), 2 * g.center(0)) + + body = PsBlock([factory.parse_sympy(asm)]) + + loops = factory.loops_from_ispace(ispace, body) + + loops_copy = loops.clone() + + ast = PsBlock([loops, loops_copy]) + + ast = canonicalize(ast) + + assert loops_copy.counter.symbol.name == "ctr_0" + assert not loops_copy.counter.symbol.get_dtype().const + assert loops.counter.symbol.name == "ctr_0__0" + assert not loops.counter.symbol.get_dtype().const -- GitLab