From 0845d667afdb41e64ff7a0aadd07a19124276f51 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 26 Apr 2024 17:13:09 +0200
Subject: [PATCH] Fix symbol canonicalization to not duplicate when marking as
 updated

---
 .../transformations/canonicalize_symbols.py   |  2 +-
 .../test_canonicalize_symbols.py              | 28 +++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py
index 3900105b8..e55807ef4 100644
--- a/src/pystencils/backend/transformations/canonicalize_symbols.py
+++ b/src/pystencils/backend/transformations/canonicalize_symbols.py
@@ -33,7 +33,7 @@ class CanonContext:
             return replacement
 
     def mark_as_updated(self, symb: PsSymbol):
-        self.updated_symbols.add(self.deduplicate(symb))
+        self.updated_symbols.add(symb)
 
     def is_live(self, symb: PsSymbol) -> bool:
         return symb in self.live_symbols_map
diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py
index 43f269163..6a4785564 100644
--- a/tests/nbackend/transformations/test_canonicalize_symbols.py
+++ b/tests/nbackend/transformations/test_canonicalize_symbols.py
@@ -86,3 +86,31 @@ def test_do_not_constify():
 
     assert ctx.find_symbol("x").dtype.const
     assert not ctx.find_symbol("z").dtype.const
+
+
+def test_loop_counters():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    canonicalize = CanonicalizeSymbols(ctx)
+
+    f = Field.create_generic("f", 2, index_shape=(1,))
+    g = Field.create_generic("g", 2, index_shape=(1,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    asm = Assignment(f.center(0), 2 * g.center(0))
+
+    body = PsBlock([factory.parse_sympy(asm)])
+
+    loops = factory.loops_from_ispace(ispace, body)
+
+    loops_copy = loops.clone()
+
+    ast = PsBlock([loops, loops_copy])
+
+    ast = canonicalize(ast)
+
+    assert loops_copy.counter.symbol.name == "ctr_0"
+    assert not loops_copy.counter.symbol.get_dtype().const
+    assert loops.counter.symbol.name == "ctr_0__0"
+    assert not loops.counter.symbol.get_dtype().const
-- 
GitLab