From e46f1658088717c69c3f04e14732472d37ae588a Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Fri, 19 Feb 2021 18:49:28 +0100
Subject: [PATCH] improve detection of unary minus

---
 pystencils/cpu/vectorization.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 4c632b145..51fa1a807 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -209,12 +209,11 @@ def insert_vector_casts(ast_node):
             if expr.func is sp.Mul and expr.args[0] == -1:
                 # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                 dtype = int
-                for arg in expr.args[1:]:
-                    if type(arg) is sp.Pow:
-                        arg = arg.args[0]
-                    if type(arg) is vector_memory_access and arg.dtype.base_type.is_float():
+                for arg in expr.atoms(vector_memory_access):
+                    if arg.dtype.base_type.is_float():
                         dtype = arg.dtype.base_type.numpy_dtype.type
-                    elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
+                for arg in expr.atoms(TypedSymbol):
+                    if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
                         dtype = arg.dtype.base_type.numpy_dtype.type
                 if dtype is not int:
                     if dtype is np.float32:
-- 
GitLab