From a4d1627578ffc21e58f7069ac2a9c8c39d81bd35 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Sat, 2 Jul 2022 16:45:46 +0200
Subject: [PATCH] Fix nontemporal stores on non-x86 for fields with variable
 size

---
 pystencils/astnodes.py                 |  6 ++++++
 pystencils_tests/test_vectorization.py | 11 +++++++++++
 2 files changed, 17 insertions(+)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index f3ed2711c..79222a05a 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -619,7 +619,13 @@ class SympyAssignment(Node):
                 for i in range(len(symbol.offsets)):
                     loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
         result.update(loop_counters)
+        
         result.update(self._lhs_symbol.atoms(sp.Symbol))
+        
+        sizes = set().union(*(a.field.shape for a in self._lhs_symbol.atoms(ResolvedFieldAccess)))
+        sizes = filter(lambda s: isinstance(s, FieldShapeSymbol), sizes)
+        result.update(sizes)
+        
         return result
 
     @property
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index f526341ec..7b97b7b0a 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -75,6 +75,17 @@ def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
     np.testing.assert_equal(np.sum(dh.cpu_arrays['f']), np.prod(domain_size))
 
 
+def test_nt_stores_symbolic_size(instruction_set=instruction_set):
+    f, g = ps.fields('f, g: [2D]', layout='fzyx')
+    update_rule = [ps.Assignment(f.center(), 0.0), ps.Assignment(g.center(), 0.0)]
+    opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'nontemporal': True,
+           'assume_inner_stride_one': True}
+    config = pystencils.config.CreateKernelConfig(target=Target.CPU, cpu_vectorize_info=opt)
+    ast = ps.create_kernel(update_rule, config=config)
+    # No assertions follow: the test passes if the nontemporal kernel compiles.
+    ast.compile()
+
+
 def test_inplace_update(instruction_set=instruction_set):
     shape = (9, 9, 3)
     arr = np.ones(shape, order='f')
-- 
GitLab