Commit 39c214af authored by Michael Kuron, committed by Markus Holzer

vectorization: improve treatment of unary minus

parent a09d1c86
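
Background for the commit message: sympy has no dedicated unary-minus node; -x is stored as Mul(-1, x), so the vectorizer has to special-case that multiplication to give the -1 the same type as its argument. A minimal illustration (plain sympy, nothing pystencils-specific):

import sympy as sp

x = sp.Symbol('x')
expr = -x
# sympy represents unary minus as multiplication by the integer -1
assert expr.func is sp.Mul
assert expr.args == (-1, x)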
@@ -546,6 +546,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         elif isinstance(expr, cast_func):
             arg, data_type = expr.args
             if type(data_type) is VectorType:
+                # vector_memory_access is a cast_func itself, so it shouldn't be directly inside a cast_func
+                assert not isinstance(arg, vector_memory_access)
                 if isinstance(arg, sp.Tuple):
                     is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
                     is_integer = get_type_of_expression(arg[0]) == create_type("int")
......
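
Note on the assertion added above: the comment in the hunk states that vector_memory_access is itself a cast_func, i.e. a subclass, so nesting one directly inside a cast_func would cast the same load twice. A hedged sketch of the invariant, assuming the pystencils.data_types module layout of this version:

from pystencils.data_types import cast_func, vector_memory_access

# vector_memory_access derives from cast_func; wrapping it in another
# cast_func would stack two casts on one load, which the assert forbids
assert issubclass(vector_memory_access, cast_func)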
 import warnings
 from typing import Container, Union
+import numpy as np
 import sympy as sp
 from sympy.logic.boolalg import BooleanFunction
@@ -160,7 +161,6 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
 def mask_conditionals(loop_body):
     def visit_node(node, mask):
         if isinstance(node, ast.Conditional):
             cond = node.condition_expr
@@ -201,18 +201,32 @@ def insert_vector_casts(ast_node):
             new_arg = visit_expr(expr.args[0])
             base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \
                 else get_type_of_expression(expr.args[0])
-            pw = sp.Piecewise((base_type.numpy_dtype.type(-1) * new_arg, new_arg < base_type.numpy_dtype.type(0)),
+            pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)),
                               (new_arg, True))
             return visit_expr(pw)
         elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
+            default_type = 'double'
+            if expr.func is sp.Mul and expr.args[0] == -1:
+                # special treatment for the unary minus: make sure that the -1 has the same type as the argument
+                dtype = int
+                for arg in expr.args[1:]:
+                    if type(arg) is vector_memory_access and arg.dtype.base_type.is_float():
+                        dtype = arg.dtype.base_type.numpy_dtype.type
+                    elif type(arg) is TypedSymbol and type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
+                        dtype = arg.dtype.base_type.numpy_dtype.type
+                if dtype is not int:
+                    if dtype is np.float32:
+                        default_type = 'float'
+                    expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
             new_args = [visit_expr(a) for a in expr.args]
-            arg_types = [get_type_of_expression(a) for a in new_args]
+            arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
             if not any(type(t) is VectorType for t in arg_types):
                 return expr
             else:
                 target_type = collate_types(arg_types)
-                casted_args = [cast_func(a, target_type) if t != target_type else a
-                               for a, t in zip(new_args, arg_types)]
+                casted_args = [
+                    cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a
+                    for a, t in zip(new_args, arg_types)]
                 return expr.func(*casted_args)
         elif expr.func is sp.Pow:
             new_arg = visit_expr(expr.args[0])
......
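
The Piecewise in the hunk above is how sp.Abs is lowered when the target instruction set lacks a native abs: |x| = -x if x < 0, else x. With the new unary-minus handling, the plain -new_arg form is enough, because the recursive visit_expr call now types the implicit -1 to match its argument. A standalone check of the lowering:

import sympy as sp

x = sp.Symbol('x')
# branchless select equivalent to abs, as emitted by insert_vector_casts
pw = sp.Piecewise((-x, x < 0), (x, True))
for v in (-3, 0, 5):
    assert pw.subs(x, v) == abs(v)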
@@ -502,6 +502,9 @@ def get_type_of_expression(expr,
     from pystencils.astnodes import ResolvedFieldAccess
     from pystencils.cpu.vectorization import vec_all, vec_any
 
+    if default_float_type == 'float':
+        default_float_type = 'float32'
+
     if not symbol_type_dict:
         symbol_type_dict = defaultdict(lambda: create_type('double'))
......
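
The alias above exists because the vectorizer now passes the C-style name 'float' (via default_type in insert_vector_casts), while the numpy-based type factory spells single precision 'float32'. A hedged sketch of the naming mismatch, assuming pystencils' create_type helper:

import numpy as np
from pystencils.data_types import create_type

# numpy knows 'float32'; a plain 'float' would resolve to a 64-bit double in numpy
assert create_type('float32').numpy_dtype == np.dtype(np.float32)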
@@ -36,6 +36,8 @@ QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
 {
 #ifdef __AVX512VL__
     return _mm_cvtepu64_pd(x);
+#elif defined(__clang__)
+    return __builtin_convertvector((uint64_t __attribute__((__vector_size__(16)))) x, __m128d);
 #else
     __m128i xH = _mm_srli_epi64(x, 32);
     xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84
@@ -85,6 +87,8 @@ QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
 {
 #ifdef __AVX512VL__
     return _mm256_cvtepu64_pd(x);
+#elif defined(__clang__)
+    return __builtin_convertvector((uint64_t __attribute__((__vector_size__(32)))) x, __m256d);
 #else
     __m256i xH = _mm256_srli_epi64(x, 32);
     xH = _mm256_or_si256(xH, _mm256_castpd_si256(_mm256_set1_pd(19342813113834066795298816.))); // 2^84
......
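
The #else branch that clang now skips converts uint64 lanes to double with magic constants: the high 32 bits are OR-ed into the all-zero mantissa of 2^84 (that is the 19342813113834066795298816. literal) and the low 32 bits into the mantissa of 2^52, then both biases are subtracted at once. A scalar Python sketch of the same bit trick:

import struct

def _bits(d):
    return struct.unpack('<Q', struct.pack('<d', d))[0]

def _dbl(b):
    return struct.unpack('<d', struct.pack('<Q', b))[0]

def u64_to_double(x):
    hi, lo = x >> 32, x & 0xFFFFFFFF
    # 2^84 and 2^52 have all-zero mantissas, so OR-ing a 32-bit value into
    # the low mantissa bits adds exactly hi * 2**32 resp. lo * 2**0
    d_hi = _dbl(_bits(2.0 ** 84) | hi)
    d_lo = _dbl(_bits(2.0 ** 52) | lo)
    # subtracting both biases in one step keeps the rounding correct
    return (d_hi - (2.0 ** 84 + 2.0 ** 52)) + d_lo

assert u64_to_double(2 ** 64 - 1) == float(2 ** 64 - 1)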
@@ -73,7 +73,7 @@ def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0),
     for x in range(dh.shape[0]):
         for y in range(dh.shape[1]):
             r = Philox(counter=t + (x + offset_values[0]) * 2 ** 32 + (y + offset_values[1]) * 2 ** 64,
-                       key=keys[0] + keys[1] * 2 ** 32, number=4, width=32)
+                       key=keys[0] + keys[1] * 2 ** 32, number=4, width=32, mode="sequence")
             r.advance(-4, counter=False)
             int_reference[x, y, :] = r.random_raw(size=4)
......
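
The new mode="sequence" argument pins down which counter/key interpretation randomgen's Philox uses; presumably a newer randomgen release changed the default, and "sequence" keeps the reference stream that the kernel output is compared against. A hedged usage sketch, assuming the randomgen Philox API this test relies on:

from randomgen import Philox

# mode="sequence" fixes the counter layout across randomgen versions
r = Philox(counter=124, key=0, number=4, width=32, mode="sequence")
r.advance(-4, counter=False)
print(r.random_raw(size=4))  # four raw 32-bit draws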
@@ -33,22 +33,6 @@ def test_vector_type_propagation():
     np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
 
-
-def test_vectorized_abs():
-    arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2))
-    arr[-3:, :] = -1
-
-    f, g = ps.fields(f=arr, g=arr)
-    update_rule = [ps.Assignment(g.center(), sp.Abs(f.center()))]
-    ast = ps.create_kernel(update_rule)
-
-    vectorize(ast, instruction_set=instruction_set)
-    func = ast.compile()
-    dst = np.zeros_like(arr)
-    func(g=dst, f=arr)
-    np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)
-
-
 def test_aligned_and_nt_stores():
     domain_size = (24, 24)
     # create a datahandling object
......
 import pytest
 import numpy as np
+import sympy as sp
 import pystencils as ps
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
@@ -28,3 +29,23 @@ def test_vectorisation_varying_arch(instruction_set):
     kernel = ast.compile()
     kernel(f=arr)
     np.testing.assert_equal(arr, 2)
+
+
+@pytest.mark.parametrize('dtype', ('float', 'double'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_vectorized_abs(instruction_set, dtype):
+    """Some instruction sets have abs, some don't.
+       Furthermore, the special treatment of unary minus makes this data type-sensitive too.
+    """
+
+    arr = np.ones((2 ** 2 + 2, 2 ** 3 + 2), dtype=np.float64 if dtype == 'double' else np.float32)
+    arr[-3:, :] = -1
+
+    f, g = ps.fields(f=arr, g=arr)
+    update_rule = [ps.Assignment(g.center(), sp.Abs(f.center()))]
+    ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
+
+    func = ast.compile()
+    dst = np.zeros_like(arr)
+    func(g=dst, f=arr)
+    np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)