diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 420e57c709e1697885ea686fe2596ad4c417b5ff..ae4c0e38894043326548efb63d08c6d79103486c 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -257,22 +257,24 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs) - def visit_expr(expr, default_type='double'): # TODO Vectorization Revamp: get rid of default_type + # TODO Vectorization Revamp: get rid of default_type + def visit_expr(expr, default_type='double', force_vectorize=False): if isinstance(expr, VectorMemoryAccess): - return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) + return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type, force_vectorize), + *expr.args[5:]) elif isinstance(expr, CastFunc): cast_type = expr.args[1] - arg = visit_expr(expr.args[0]) + arg = visit_expr(expr.args[0], default_type, force_vectorize) assert cast_type in [BasicType('float32'), BasicType('float64')],\ f'Vectorization cannot vectorize type {cast_type}' return expr.func(arg, VectorType(cast_type, instruction_set['width'])) elif expr.func is sp.Abs and 'abs' not in instruction_set: - new_arg = visit_expr(expr.args[0], default_type) + new_arg = visit_expr(expr.args[0], default_type, force_vectorize) base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ else get_type_of_expression(expr.args[0]) pw = sp.Piecewise((-new_arg, new_arg < CastFunc(0, base_type.numpy_dtype)), (new_arg, True)) - return visit_expr(pw, default_type) + return visit_expr(pw, default_type, force_vectorize) elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction): if expr.func is sp.Mul and expr.args[0] == -1: # special treatment for the unary minus: make sure that the -1 has the same type 
as the argument @@ -287,7 +289,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): if dtype is np.float32: default_type = 'float' expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:]) - new_args = [visit_expr(a, default_type) for a in expr.args] + new_args = [visit_expr(a, default_type, force_vectorize) for a in expr.args] arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args] if not any(type(t) is VectorType for t in arg_types): return expr @@ -306,7 +308,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): exp = expr.args[0].exp expr = sp.UnevaluatedExpr(sp.Mul(*([base] * +exp), evaluate=False)) - new_args = [visit_expr(a, default_type) for a in expr.args[0].args] + new_args = [visit_expr(a, default_type, force_vectorize) for a in expr.args[0].args] arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args] target_type = collate_types(arg_types) @@ -318,11 +320,11 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): for a, t in zip(new_args, arg_types)] return expr.func(expr.args[0].func(*casted_args, evaluate=False)) elif expr.func is sp.Pow: - new_arg = visit_expr(expr.args[0], default_type) + new_arg = visit_expr(expr.args[0], default_type, force_vectorize) return expr.func(new_arg, expr.args[1]) elif expr.func == sp.Piecewise: - new_results = [visit_expr(a[0], default_type) for a in expr.args] - new_conditions = [visit_expr(a[1], default_type) for a in expr.args] + new_results = [visit_expr(a[0], default_type, force_vectorize) for a in expr.args] + new_conditions = [visit_expr(a[1], default_type, force_vectorize) for a in expr.args] types_of_results = [get_type_of_expression(a) for a in new_results] types_of_conditions = [get_type_of_expression(a) for a in new_conditions] @@ -341,7 +343,14 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): for a, t in 
zip(new_conditions, types_of_conditions)] return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)]) - elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)): + elif isinstance(expr, TypedSymbol): + if force_vectorize: + expr_type = get_type_of_expression(expr) + if type(expr_type) is not VectorType: + vector_type = VectorType(expr_type, instruction_set['width']) + return CastFunc(expr, vector_type) + return expr + elif isinstance(expr, (sp.Number, BooleanAtom)): return expr else: raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n' @@ -357,11 +366,18 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): # continue subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) - assignment.rhs = visit_expr(subs_expr, default_type) - rhs_type = get_type_of_expression(assignment.rhs) + + # If either side contains a vectorized subexpression, both sides + # must be fully vectorized. 
+                lhs_type = get_type_of_expression(assignment.lhs)
+                rhs_type = get_type_of_expression(subs_expr)
+                lhs_vectorized = type(lhs_type) is VectorType
+                rhs_vectorized = type(rhs_type) is VectorType
+
+                assignment.rhs = visit_expr(subs_expr, default_type, force_vectorize=lhs_vectorized or rhs_vectorized)
+
                 if isinstance(assignment.lhs, TypedSymbol):
-                    lhs_type = assignment.lhs.dtype
-                    if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
+                    if rhs_vectorized and not lhs_vectorized:
                         new_lhs_type = VectorType(lhs_type, rhs_type.width)
                         new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
                         substitution_dict[assignment.lhs] = new_lhs
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 3718770b9384e0f01dd899c3afed298208fc95a9..83320a91f60677e05ab484f7e83b428df47adc64 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -6,6 +6,7 @@ import pystencils.config
 import sympy as sp
 
 import pystencils as ps
+import pystencils.astnodes as ast
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
 from pystencils.cpu.vectorization import vectorize
 from pystencils.enums import Target
@@ -40,6 +41,49 @@ def test_vector_type_propagation(instruction_set=instruction_set):
     np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
 
 
+def test_vectorize_moved_constants1(instruction_set=instruction_set):
+    """Vectorize a kernel that stores one pre-loop constant `x` into `f`; every element of `f` must equal 2."""
+    opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
+
+    f = ps.fields("f: [1D]")
+    x = ast.TypedSymbol("x", np.float64)
+
+    kernel_func = ps.create_kernel(
+        [ast.SympyAssignment(x, 2.0), ast.SympyAssignment(f[0], x)],
+        cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop],  # explicitly move constants
+        cpu_vectorize_info=opt,
+    )
+    ps.show_code(kernel_func)  # fails if `x` on rhs was not correctly vectorized
+    kernel = kernel_func.compile()
+
+    f_arr = np.zeros(9)
+    kernel(f=f_arr)
+
+    np.testing.assert_equal(f_arr, 2.0)
+
+
+def test_vectorize_moved_constants2(instruction_set=instruction_set):
+    """Vectorize a kernel that stores the sum of two pre-loop constants into `f`; every element must equal 5."""
+    opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
+
+    f = ps.fields("f: [1D]")
+    x = ast.TypedSymbol("x", np.float64)
+    y = ast.TypedSymbol("y", np.float64)
+
+    kernel_func = ps.create_kernel(
+        [ast.SympyAssignment(x, 2.0), ast.SympyAssignment(y, 3.0), ast.SympyAssignment(f[0], x + y)],
+        cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop],  # explicitly move constants
+        cpu_vectorize_info=opt,
+    )
+    ps.show_code(kernel_func)  # fails if `x` and `y` on rhs were not correctly vectorized
+    kernel = kernel_func.compile()
+
+    f_arr = np.zeros(9)
+    kernel(f=f_arr)
+
+    np.testing.assert_equal(f_arr, 5.0)
+
+
 @pytest.mark.parametrize('openmp', [True, False])
 def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
     domain_size = (24, 24)