diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 4c632b1457c1e2778953a55d4378ee87f05ffb82..51fa1a807db17b8ba58fda8bf080658015b33ef8 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -209,12 +209,11 @@ def insert_vector_casts(ast_node): if expr.func is sp.Mul and expr.args[0] == -1: # special treatment for the unary minus: make sure that the -1 has the same type as the argument dtype = int - for arg in expr.args[1:]: - if type(arg) is sp.Pow: - arg = arg.args[0] - if type(arg) is vector_memory_access and arg.dtype.base_type.is_float(): + for arg in expr.atoms(vector_memory_access): + if arg.dtype.base_type.is_float(): dtype = arg.dtype.base_type.numpy_dtype.type - elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float(): + for arg in expr.atoms(TypedSymbol): + if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float(): dtype = arg.dtype.base_type.numpy_dtype.type if dtype is not int: if dtype is np.float32: