Skip to content
Snippets Groups Projects

vectorization: improve treatment of unary minus

Merged Michael Kuron requested to merge philox-simd into master
Viewing commit 6af845a8
Show latest version
2 files
+ 10
2
Preferences
Compare changes
Files
2
import warnings
import warnings
from typing import Container, Union
from typing import Container, Union
 
import numpy as np
import sympy as sp
import sympy as sp
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanFunction
@@ -205,6 +206,7 @@ def insert_vector_casts(ast_node):
@@ -205,6 +206,7 @@ def insert_vector_casts(ast_node):
(new_arg, True))
(new_arg, True))
return visit_expr(pw)
return visit_expr(pw)
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
 
default_type = 'double'
if expr.func is sp.Mul and expr.args[0] == -1:
if expr.func is sp.Mul and expr.args[0] == -1:
# special treatment for the unary minus: make sure that the -1 has the same type as the argument
# special treatment for the unary minus: make sure that the -1 has the same type as the argument
dtype = int
dtype = int
@@ -213,9 +215,12 @@ def insert_vector_casts(ast_node):
@@ -213,9 +215,12 @@ def insert_vector_casts(ast_node):
dtype = arg.dtype.base_type.numpy_dtype.type
dtype = arg.dtype.base_type.numpy_dtype.type
elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
dtype = arg.dtype.base_type.numpy_dtype.type
dtype = arg.dtype.base_type.numpy_dtype.type
expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
if dtype is not int:
 
if dtype is np.float32:
 
default_type = 'float'
 
expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
new_args = [visit_expr(a) for a in expr.args]
new_args = [visit_expr(a) for a in expr.args]
arg_types = [get_type_of_expression(a) for a in new_args]
arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
if not any(type(t) is VectorType for t in arg_types):
if not any(type(t) is VectorType for t in arg_types):
return expr
return expr
else:
else: