From 136f7c95fd25637400818c296d98e85e788a6c49 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Tue, 8 Feb 2022 15:38:06 +0100
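Subject: [PATCH] Fix Neon vector instruction set

Use SympyAssignment instead of ps.Assignment for the kernel built in
get_cacheline_size, add sp.Abs to the functions handled by
insert_vector_casts, and parametrize test_aligned_and_nt_stores over
openmp in place of the separate test_aligned_and_nt_stores_openmp
wrapper. test_vectorization_other now skips test_aligned_and_nt_stores.
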
---
 pystencils/backends/simd_instruction_sets.py    | 3 ++-
 pystencils/cpu/vectorization.py                 | 2 +-
 pystencils_tests/test_vectorization.py          | 7 ++-----
 pystencils_tests/test_vectorization_specific.py | 2 +-
 4 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py
index 8ac0beeb7..f2df61963 100644
--- a/pystencils/backends/simd_instruction_sets.py
+++ b/pystencils/backends/simd_instruction_sets.py
@@ -98,12 +98,13 @@ def get_cacheline_size(instruction_set):
         return _cachelinesize
     
     import pystencils as ps
+    from pystencils.astnodes import SympyAssignment
     import numpy as np
     from pystencils.cpu.vectorization import CachelineSize
     
     arr = np.zeros((1, 1), dtype=np.float32)
     f = ps.Field.create_from_numpy_array('f', arr, index_dimensions=0)
-    ass = [CachelineSize(), ps.Assignment(f.center, CachelineSize.symbol)]
+    ass = [CachelineSize(), SympyAssignment(f.center, CachelineSize.symbol)]
     ast = ps.create_kernel(ass, cpu_vectorize_info={'instruction_set': instruction_set})
     kernel = ast.compile()
     kernel(**{f.name: arr, CachelineSize.symbol.name: 0})
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index ac25639b1..6d59be9b0 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -254,7 +254,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
     """Inserts necessary casts from scalar values to vector values."""
 
     handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc,
-                         sp.UnevaluatedExpr)
+                         sp.UnevaluatedExpr, sp.Abs)
 
     def visit_expr(expr, default_type='double'):  # TODO get rid of default_type
         if isinstance(expr, VectorMemoryAccess):
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 8b685c28a..c058691b7 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -42,7 +42,8 @@ def test_vector_type_propagation(instruction_set=instruction_set):
     np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
 
 
-def test_aligned_and_nt_stores(instruction_set=instruction_set, openmp=False):
+@pytest.mark.parametrize('openmp', [True, False])
+def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
     domain_size = (24, 24)
     # create a datahandling object
     dh = ps.create_data_handling(domain_size, periodicity=(True, True), parallel=False, default_target=Target.CPU)
@@ -76,10 +77,6 @@ def test_aligned_and_nt_stores(instruction_set=instruction_set, openmp=False):
     np.testing.assert_equal(np.sum(dh.cpu_arrays['f']), np.prod(domain_size))
 
 
-def test_aligned_and_nt_stores_openmp(instruction_set=instruction_set):
-    test_aligned_and_nt_stores(instruction_set, True)
-
-
 def test_inplace_update(instruction_set=instruction_set):
     shape = (9, 9, 3)
     arr = np.ones(shape, order='f')
diff --git a/pystencils_tests/test_vectorization_specific.py b/pystencils_tests/test_vectorization_specific.py
index 3a5697066..e118930b0 100644
--- a/pystencils_tests/test_vectorization_specific.py
+++ b/pystencils_tests/test_vectorization_specific.py
@@ -128,7 +128,7 @@ def test_cacheline_size(instruction_set):
 @pytest.mark.parametrize('instruction_set',
                          sorted(set(supported_instruction_sets) - {test_vectorization.instruction_set}))
 @pytest.mark.parametrize('function',
-                         [f for f in test_vectorization.__dict__ if f.startswith('test_') and f != 'test_hardware_query'])
+                         [f for f in test_vectorization.__dict__ if f.startswith('test_') and f not in ['test_hardware_query', 'test_aligned_and_nt_stores']])
 def test_vectorization_other(instruction_set, function):
     test_vectorization.__dict__[function](instruction_set)
 
-- 
GitLab