Showing with 4359 additions and 0 deletions
from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler
from pystencils.gpu.gpujit import make_python_function
from pystencils.gpu.kernelcreation import create_cuda_kernel, created_indexed_cuda_kernel
from .indexing import AbstractIndexing, BlockIndexing, LineIndexing
__all__ = ['GPUArrayHandler', 'GPUNotAvailableHandler',
'create_cuda_kernel', 'created_indexed_cuda_kernel', 'make_python_function',
'AbstractIndexing', 'BlockIndexing', 'LineIndexing']
try:
import cupy as cp
import cupyx as cpx
except ImportError:
cp = None
cpx = None
import numpy as np
class GPUArrayHandler:
def __init__(self, device_number):
self._device_number = device_number
def zeros(self, shape, dtype=np.float64, order='C'):
with cp.cuda.Device(self._device_number):
return cp.zeros(shape=shape, dtype=dtype, order=order)
def ones(self, shape, dtype=np.float64, order='C'):
with cp.cuda.Device(self._device_number):
return cp.ones(shape=shape, dtype=dtype, order=order)
def empty(self, shape, dtype=np.float64, order='C'):
with cp.cuda.Device(self._device_number):
return cp.empty(shape=shape, dtype=dtype, order=order)
def to_gpu(self, numpy_array):
swaps = _get_index_swaps(numpy_array)
if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray):
with cp.cuda.Device(self._device_number):
gpu_array = cp.asarray(numpy_array.base)
for a, b in reversed(swaps):
gpu_array = gpu_array.swapaxes(a, b)
return gpu_array
else:
return cp.asarray(numpy_array)
def upload(self, array, numpy_array):
assert self._device_number == array.device.id
if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray):
with cp.cuda.Device(self._device_number):
array.base.set(numpy_array.base)
else:
with cp.cuda.Device(self._device_number):
array.set(numpy_array)
def download(self, array, numpy_array):
assert self._device_number == array.device.id
if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray):
with cp.cuda.Device(self._device_number):
numpy_array.base[:] = array.base.get()
else:
with cp.cuda.Device(self._device_number):
numpy_array[:] = array.get()
def randn(self, shape, dtype=np.float64):
with cp.cuda.Device(self._device_number):
return cp.random.randn(*shape, dtype=dtype)
@staticmethod
def pinned_numpy_array(layout, shape, dtype):
assert set(layout) == set(range(len(shape))), "Wrong layout descriptor"
cur_layout = list(range(len(shape)))
swaps = []
for i in range(len(layout)):
if cur_layout[i] != layout[i]:
index_to_swap_with = cur_layout.index(layout[i])
swaps.append((i, index_to_swap_with))
cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i]
assert tuple(cur_layout) == tuple(layout)
shape = list(shape)
for a, b in swaps:
shape[a], shape[b] = shape[b], shape[a]
res = cpx.empty_pinned(tuple(shape), order='c', dtype=dtype)
for a, b in reversed(swaps):
res = res.swapaxes(a, b)
return res
from_numpy = to_gpu
class GPUNotAvailableHandler:
def __getattribute__(self, name):
raise NotImplementedError("Unable to utilise cupy! Please make sure cupy works correctly in your setup!")
def _get_index_swaps(array):
swaps = []
if array.base is not None and isinstance(array.base, np.ndarray):
for stride in array.base.strides:
index_base = array.base.strides.index(stride)
index_view = array.strides.index(stride)
if index_base != index_view and (index_view, index_base) not in swaps:
swaps.append((index_base, index_view))
return swaps
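# Usage sketch (illustrative only, not part of the module; assumes cupy is installed
# and GPU device 0 is available):
#
#     import numpy as np
#     handler = GPUArrayHandler(device_number=0)
#     host = handler.pinned_numpy_array(layout=(0, 1), shape=(16, 32), dtype=np.float64)
#     host[...] = 1.0
#     dev = handler.to_gpu(host)        # copy to a new device array
#     handler.download(dev, host)       # copy back into the pinned host buffer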
import numpy as np
from pystencils.backends.cbackend import get_headers
from pystencils.backends.cuda_backend import generate_cuda
from pystencils.typing import StructType
from pystencils.field import FieldType
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.typing import BasicType, FieldPointerSymbol
USE_FAST_MATH = True
def get_cubic_interpolation_include_paths():
from os.path import join, dirname
return [join(dirname(__file__), "CubicInterpolationCUDA", "code"),
join(dirname(__file__), "CubicInterpolationCUDA", "code", "internal")]
def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
"""
Creates a kernel function from an abstract syntax tree which
was created e.g. by :func:`pystencils.gpu.create_cuda_kernel`
or :func:`pystencils.gpu.created_indexed_cuda_kernel`
Args:
kernel_function_node: the abstract syntax tree
argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
returned kernel functor.
custom_backend: use own custom printer for code generation
Returns:
compiled kernel as Python function
"""
import cupy as cp
if argument_dict is None:
argument_dict = {}
headers = get_headers(kernel_function_node)
if cp.cuda.runtime.is_hip:
headers.add('"gpu_defines.h"')
for field in kernel_function_node.fields_accessed:
if isinstance(field.dtype, BasicType) and field.dtype.is_half():
headers.add('<hip/hip_fp16.h>')
else:
headers.update({'"gpu_defines.h"', '<cstdint>'})
for field in kernel_function_node.fields_accessed:
if isinstance(field.dtype, BasicType) and field.dtype.is_half():
headers.add('<cuda_fp16.h>')
header_list = sorted(headers)
includes = "\n".join([f"#include {include_file}" for include_file in header_list])
code = includes + "\n"
code += "#define FUNC_PREFIX __global__\n"
code += "#define RESTRICT __restrict__\n\n"
code += 'extern "C" {\n%s\n}\n' % str(generate_cuda(kernel_function_node, custom_backend=custom_backend))
options = ["-w", "-std=c++11"]
if USE_FAST_MATH:
options.append("-use_fast_math")
options.append("-I" + get_pystencils_include_path())
func = cp.RawKernel(code, kernel_function_node.function_name, options=tuple(options), backend="nvrtc", jitify=True)
parameters = kernel_function_node.get_parameters()
cache = {}
cache_values = []
def wrapper(**kwargs):
key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
for k, v in kwargs.items()))
try:
args, block_and_thread_numbers = cache[key]
device = set(a.device.id for a in args if type(a) is cp.ndarray)
assert len(device) == 1, "All arrays used by a kernel need to be allocated on the same device"
with cp.cuda.Device(device.pop()):
func(block_and_thread_numbers['grid'], block_and_thread_numbers['block'], args)
except KeyError:
full_arguments = argument_dict.copy()
full_arguments.update(kwargs)
shape = _check_arguments(parameters, full_arguments)
indexing = kernel_function_node.indexing
block_and_thread_numbers = indexing.call_parameters(shape)
block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])
args = tuple(_build_numpy_argument_list(parameters, full_arguments))
cache[key] = (args, block_and_thread_numbers)
cache_values.append(kwargs) # keep objects alive such that ids remain unique
device = set(a.device.id for a in args if type(a) is cp.ndarray)
assert len(device) == 1, "All arrays used by a kernel need to be allocated on the same device"
with cp.cuda.Device(device.pop()):
func(block_and_thread_numbers['grid'], block_and_thread_numbers['block'], args)
# useful for debugging:
# cp.cuda.runtime.deviceSynchronize()
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
ast = kernel_function_node
parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs
return wrapper
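# Sketch of how the returned functor is used (illustrative; the field name 'f', the shape
# and the variable kernel_ast are assumptions, keyword arguments must match the kernel's
# field names):
#
#     import cupy as cp
#     kernel = make_python_function(kernel_ast)   # kernel_ast from create_cuda_kernel
#     f_gpu = cp.zeros((32, 32))
#     kernel(f=f_gpu)                             # launch configuration comes from kernel_ast.indexing
#     cp.cuda.runtime.deviceSynchronize()         # kernel launches are asynchronous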
def _build_numpy_argument_list(parameters, argument_dict):
argument_dict = {k: v for k, v in argument_dict.items()}
result = []
for param in parameters:
if param.is_field_pointer:
array = argument_dict[param.field_name]
actual_type = array.dtype
expected_type = param.fields[0].dtype.numpy_dtype
if expected_type != actual_type:
raise ValueError(f"Data type mismatch for field {param.field_name}. "
f"Expected {expected_type} got {actual_type}.")
result.append(array)
elif param.is_field_stride:
cast_to_dtype = param.symbol.dtype.numpy_dtype.type
array = argument_dict[param.field_name]
stride = cast_to_dtype(array.strides[param.symbol.coordinate] // array.dtype.itemsize)
result.append(stride)
elif param.is_field_shape:
cast_to_dtype = param.symbol.dtype.numpy_dtype.type
array = argument_dict[param.field_name]
result.append(cast_to_dtype(array.shape[param.symbol.coordinate]))
else:
expected_type = param.symbol.dtype.numpy_dtype
result.append(expected_type.type(argument_dict[param.symbol.name]))
assert len(result) == len(parameters)
return result
def _check_arguments(parameter_specification, argument_dict):
"""
Checks if parameters passed to kernel match the description in the AST function node.
If not, it raises a ValueError; on success, it returns the array shape that determines the CUDA blocks and threads
"""
argument_dict = {k: v for k, v in argument_dict.items()}
array_shapes = set()
index_arr_shapes = set()
for param in parameter_specification:
if isinstance(param.symbol, FieldPointerSymbol):
symbolic_field = param.fields[0]
try:
field_arr = argument_dict[symbolic_field.name]
except KeyError:
raise KeyError(f"Missing field parameter for kernel call {str(symbolic_field)}")
if symbolic_field.has_fixed_shape:
symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_shape = symbolic_field_shape[:-1]
if symbolic_field_shape != field_arr.shape:
raise ValueError(f"Passed array {symbolic_field.name} has shape {str(field_arr.shape)} "
f"which does not match expected shape {str(symbolic_field.shape)}")
if symbolic_field.has_fixed_shape:
symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_strides = symbolic_field_strides[:-1]
if symbolic_field_strides != field_arr.strides:
raise ValueError(f"Passed array {symbolic_field.name} has strides {str(field_arr.strides)} "
f"which does not match expected strides {str(symbolic_field_strides)}")
if FieldType.is_indexed(symbolic_field):
index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
elif FieldType.is_generic(symbolic_field):
array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
if len(array_shapes) > 1:
raise ValueError(f"All passed arrays have to have the same size {str(array_shapes)}")
if len(index_arr_shapes) > 1:
raise ValueError(f"All passed index arrays have to have the same size {str(array_shapes)}")
if len(index_arr_shapes) > 0:
return list(index_arr_shapes)[0]
else:
return list(array_shapes)[0]
import abc
from functools import partial
import math
from typing import List, Tuple
import sympy as sp
from sympy.core.cache import cacheit
from pystencils.astnodes import Block, Conditional, SympyAssignment
from pystencils.typing import TypedSymbol, create_type
from pystencils.integer_functions import div_ceil, div_floor
from pystencils.sympyextensions import is_integer_sequence, prod
class ThreadIndexingSymbol(TypedSymbol):
def __new__(cls, *args, **kwds):
obj = ThreadIndexingSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, name, dtype, *args, **kwargs):
obj = super(ThreadIndexingSymbol, cls).__xnew__(cls, name, dtype, *args, **kwargs)
return obj
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
BLOCK_IDX = [ThreadIndexingSymbol("blockIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
THREAD_IDX = [ThreadIndexingSymbol("threadIdx." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
BLOCK_DIM = [ThreadIndexingSymbol("blockDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for coord in ('x', 'y', 'z')]
class AbstractIndexing(abc.ABC):
"""
Abstract base class for all Indexing classes. An Indexing class defines how an iteration space is mapped
to the GPU's block and grid system. It calculates indices based on the GPU's thread and block indices
and computes the number of blocks and threads a kernel is started with.
The Indexing class is created with an iteration space that is given as a list of slices to determine start, stop
and the step size for each coordinate. Furthermore, the data_layout is given as a tuple to determine the fast and slow
coordinates. This is important to get an optimal mapping of coordinates to GPU threads.
"""
def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple):
for iter_space in iteration_space:
assert isinstance(iter_space, slice), f"iteration_space must be of type Tuple[slice], " \
f"not tuple of type {type(iter_space)}"
self._iteration_space = iteration_space
self._data_layout = data_layout
self._dim = len(iteration_space)
@property
def iteration_space(self):
"""Iteration space to loop over"""
return self._iteration_space
@property
def data_layout(self):
"""Data layout of the kernels arrays"""
return self._data_layout
@property
def dim(self):
"""Number of spatial dimensions"""
return self._dim
@property
@abc.abstractmethod
def coordinates(self):
"""Returns a sequence of coordinate expressions for (x,y,z) depending on symbolic GPU block and thread indices.
These symbolic indices can be obtained with the method `index_variables` """
@property
def index_variables(self):
"""Sympy symbols for GPU's block and thread indices, and block and grid dimensions. """
return BLOCK_IDX + THREAD_IDX + BLOCK_DIM + GRID_DIM
@abc.abstractmethod
def get_loop_ctr_assignments(self, loop_counter_symbols) -> List[SympyAssignment]:
"""Adds assignments for the loop counter symbols depending on the gpu threads.
Args:
loop_counter_symbols: typed symbols representing the loop counters
Returns:
assignments for the loop counters
"""
@abc.abstractmethod
def call_parameters(self, arr_shape):
"""Determine grid and block size for kernel call.
Args:
arr_shape: the numeric (not symbolic) shape of the array
Returns:
dict with keys 'block' and 'grid' with tuple values for the number of (x,y,z) threads and blocks
the kernel should be started with
"""
@abc.abstractmethod
def guard(self, kernel_content, arr_shape):
"""In some indexing schemes not all threads of a block execute the kernel content.
This function can return a Conditional ast node, defining this execution guard.
Args:
kernel_content: the actual kernel contents which can e.g. be put into the Conditional node as true block
arr_shape: the numeric or symbolic shape of the field
Returns:
ast node, which is put inside the kernel function
"""
@abc.abstractmethod
def max_threads_per_block(self):
"""Return maximal number of threads per block for launch bounds. If this cannot be determined without
knowing the array shape return None for unknown """
@abc.abstractmethod
def symbolic_parameters(self):
"""Set of symbols required in call_parameters code"""
# -------------------------------------------- Implementations ---------------------------------------------------------
class BlockIndexing(AbstractIndexing):
"""Generic indexing scheme that maps sub-blocks of an array to GPU blocks.
Args:
iteration_space: list of slices to determine start, stop and the step size for each coordinate
data_layout: tuple specifying loop order with innermost loop last.
This is the same format as returned by `Field.layout`.
permute_block_size_dependent_on_layout: if True, the block_size is permuted such that the fastest coordinate
gets the largest number of threads
compile_time_block_size: compile in concrete block size, otherwise the gpu variable 'blockDim' is used
maximum_block_size: maximum block size that is possible for the GPU. Set to 'auto' to let cupy define the
maximum block size from the device properties
device_number: device number of the used GPU. By default, the zeroth device is used.
"""
def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple[int],
block_size=(128, 2, 1), permute_block_size_dependent_on_layout=True, compile_time_block_size=False,
maximum_block_size=(1024, 1024, 64), device_number=None):
super(BlockIndexing, self).__init__(iteration_space, data_layout)
if self._dim > 4:
raise NotImplementedError("This indexing scheme supports at most 4 spatial dimensions")
if permute_block_size_dependent_on_layout and self._dim < 4:
block_size = self.permute_block_size_according_to_layout(block_size, data_layout)
self._block_size = block_size
if maximum_block_size == 'auto':
assert device_number is not None, 'If "maximum_block_size" is set to "auto" a device number must be stated'
# Get device limits
import cupy as cp
# See https://github.com/cupy/cupy/issues/7676
if cp.cuda.runtime.is_hip:
maximum_block_size = tuple(cp.cuda.runtime.deviceGetAttribute(i, device_number) for i in range(26, 29))
else:
da = cp.cuda.Device(device_number).attributes
maximum_block_size = tuple(da[f"MaxBlockDim{c}"] for c in ["X", "Y", "Z"])
self._maximum_block_size = maximum_block_size
self._compile_time_block_size = compile_time_block_size
self._device_number = device_number
@property
def cuda_indices(self):
block_size = self._block_size if self._compile_time_block_size else BLOCK_DIM
indices = [block_index * bs + thread_idx
for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)]
return indices[:self._dim]
@property
def coordinates(self):
if self._dim < 4:
coordinates = [c + iter_slice.start for c, iter_slice in zip(self.cuda_indices, self._iteration_space)]
return coordinates[:self._dim]
else:
coordinates = list()
width = self._iteration_space[1].stop - self.iteration_space[1].start
coordinates.append(div_floor(self.cuda_indices[0], width))
coordinates.append(sp.Mod(self.cuda_indices[0], width))
coordinates.append(self.cuda_indices[1] + self.iteration_space[2].start)
coordinates.append(self.cuda_indices[2] + self.iteration_space[3].start)
return coordinates
def get_loop_ctr_assignments(self, loop_counter_symbols):
return _loop_ctr_assignments(loop_counter_symbols, self.coordinates, self._iteration_space)
def call_parameters(self, arr_shape):
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
widths = _get_widths(numeric_iteration_slice)
if len(widths) > 3:
widths = [widths[0] * widths[1], widths[2], widths[3]]
extend_bs = (1,) * (3 - len(self._block_size))
block_size = self._block_size + extend_bs
if not self._compile_time_block_size:
assert len(block_size) == 3
adapted_block_size = []
for i in range(len(widths)):
factor = div_floor(prod(block_size[:i]), prod(adapted_block_size))
adapted_block_size.append(sp.Min(block_size[i] * factor, widths[i]))
extend_adapted_bs = (1,) * (3 - len(adapted_block_size))
block_size = tuple(adapted_block_size) + extend_adapted_bs
block_size = tuple(sp.Min(bs, max_bs) for bs, max_bs in zip(block_size, self._maximum_block_size))
grid = tuple(div_ceil(length, block_size) for length, block_size in zip(widths, block_size))
extend_gr = (1,) * (3 - len(grid))
return {'block': block_size,
'grid': grid + extend_gr}
def guard(self, kernel_content, arr_shape):
arr_shape = arr_shape[:self._dim]
if len(self._iteration_space) - 1 == len(arr_shape):
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space[1:], arr_shape)
numeric_iteration_slice = [self.iteration_space[0]] + numeric_iteration_slice
else:
assert len(self._iteration_space) == len(arr_shape), "Iteration space must be equal to the array shape"
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice]
for i, s in enumerate(numeric_iteration_slice):
if s.step and s.step != 1:
end[i] = div_ceil(s.stop - s.start, s.step) + s.start
if self._dim < 4:
conditions = [c < e for c, e in zip(self.coordinates, end)]
else:
end = [end[0] * end[1], end[2], end[3]]
coordinates = [c + iter_slice.start for c, iter_slice in zip(self.cuda_indices, self._iteration_space[1:])]
conditions = [c < e for c, e in zip(coordinates, end)]
condition = conditions[0]
for c in conditions[1:]:
condition = sp.And(condition, c)
return Block([Conditional(condition, kernel_content)])
def numeric_iteration_space(self, arr_shape):
return _get_numeric_iteration_slice(self._iteration_space, arr_shape)
def limit_block_size_by_register_restriction(self, block_size, required_registers_per_thread):
"""Shrinks the block_size if there are too many registers used per block.
This is not done automatically, since the required_registers_per_thread are not known before compilation.
They can be obtained by ``func.num_regs`` from a cupy function.
Args:
block_size: used block size that is target for limiting
required_registers_per_thread: needed registers per thread
Returns: a smaller block_size if too many registers are used.
"""
import cupy as cp
# See https://github.com/cupy/cupy/issues/7676
if cp.cuda.runtime.is_hip:
max_registers_per_block = cp.cuda.runtime.deviceGetAttribute(71, self._device_number)
else:
device = cp.cuda.Device(self._device_number)
da = device.attributes
max_registers_per_block = da.get("MaxRegistersPerBlock")
result = list(block_size)
while True:
required_registers = math.prod(result) * required_registers_per_thread
if required_registers <= max_registers_per_block:
return result
else:
largest_list_entry_idx = max(range(len(result)), key=lambda e: result[e])
assert result[largest_list_entry_idx] >= 2
result[largest_list_entry_idx] //= 2
@staticmethod
def permute_block_size_according_to_layout(block_size, layout):
"""Returns modified block_size such that the fastest coordinate gets the biggest block dimension"""
if not is_integer_sequence(block_size):
return block_size
sorted_block_size = list(sorted(block_size, reverse=True))
while len(sorted_block_size) > len(layout):
sorted_block_size[0] *= sorted_block_size[-1]
sorted_block_size = sorted_block_size[:-1]
result = list(block_size)
for l, bs in zip(reversed(layout), sorted_block_size):
result[l] = bs
return tuple(result[:len(layout)])
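# Worked example: for layout (0, 1, 2) the fastest coordinate is the last one, so
# permute_block_size_according_to_layout((128, 2, 1), (0, 1, 2)) == (1, 2, 128),
# whereas for layout (2, 1, 0) the block size (128, 2, 1) is returned unchanged.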
def max_threads_per_block(self):
if is_integer_sequence(self._block_size):
return prod(self._block_size)
else:
return None
def symbolic_parameters(self):
return set(b for b in self._block_size if isinstance(b, sp.Symbol))
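# Usage sketch (illustrative; a 2D iteration space with one ghost layer on a 64x64 array):
#
#     space = (slice(1, 63, 1), slice(1, 63, 1))
#     indexing = BlockIndexing(space, data_layout=(0, 1), block_size=(128, 2, 1))
#     indexing.call_parameters((64, 64))
#     # -> {'block': (2, 62, 1), 'grid': (31, 1, 1)} after permuting the block size
#     #    according to the layout and capping each block dimension by the iteration width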
class LineIndexing(AbstractIndexing):
"""
Indexing scheme that assigns the innermost 'line', i.e. the elements which are adjacent in memory, to a 1D GPU block.
The fastest coordinate is indexed with threadIdx.x, the remaining coordinates are mapped to blockIdx.{x,y,z}.
This indexing scheme supports up to 4 spatial dimensions, where the innermost dimension must not be larger than the
maximum number of threads allowed in a GPU block (which depends on the device).
Args:
iteration_space: list of slices to determine start, stop and the step size for each coordinate
data_layout: tuple to determine the fast and slow coordinates.
"""
def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple):
super(LineIndexing, self).__init__(iteration_space, data_layout)
if len(iteration_space) > 4:
raise NotImplementedError("This indexing scheme supports at most 4 spatial dimensions")
@property
def cuda_indices(self):
available_indices = [THREAD_IDX[0]] + BLOCK_IDX
coordinates = available_indices[:self.dim]
fastest_coordinate = self.data_layout[-1]
coordinates[0], coordinates[fastest_coordinate] = coordinates[fastest_coordinate], coordinates[0]
return coordinates
@property
def coordinates(self):
return [i + o.start for i, o in zip(self.cuda_indices, self._iteration_space)]
def get_loop_ctr_assignments(self, loop_counter_symbols):
return _loop_ctr_assignments(loop_counter_symbols, self.coordinates, self._iteration_space)
def call_parameters(self, arr_shape):
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
widths = _get_widths(numeric_iteration_slice)
def get_shape_of_cuda_idx(cuda_idx):
if cuda_idx not in self.cuda_indices:
return 1
else:
idx = self.cuda_indices.index(cuda_idx)
return widths[idx]
return {'block': tuple([get_shape_of_cuda_idx(idx) for idx in THREAD_IDX]),
'grid': tuple([get_shape_of_cuda_idx(idx) for idx in BLOCK_IDX])}
def guard(self, kernel_content, arr_shape):
return kernel_content
def max_threads_per_block(self):
return None
def symbolic_parameters(self):
return set()
def numeric_iteration_space(self, arr_shape):
return _get_numeric_iteration_slice(self._iteration_space, arr_shape)
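# Usage sketch (illustrative): each innermost line becomes one 1D GPU block.
#
#     space = (slice(0, 64, 1), slice(0, 128, 1))
#     indexing = LineIndexing(space, data_layout=(0, 1))   # coordinate 1 is the fastest
#     indexing.call_parameters((64, 128))
#     # -> {'block': (128, 1, 1), 'grid': (64, 1, 1)}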
# -------------------------------------- Helper functions --------------------------------------------------------------
def _get_numeric_iteration_slice(iteration_slice, arr_shape):
res = []
for slice_component, shape in zip(iteration_slice, arr_shape):
result_slice = slice_component
if not isinstance(result_slice.start, int):
start = result_slice.start
assert len(start.free_symbols) == 1
start = start.subs({symbol: shape for symbol in start.free_symbols})
result_slice = slice(start, result_slice.stop, result_slice.step)
if not isinstance(result_slice.stop, int):
stop = result_slice.stop
assert len(stop.free_symbols) == 1
stop = stop.subs({symbol: shape for symbol in stop.free_symbols})
result_slice = slice(result_slice.start, stop, result_slice.step)
assert isinstance(result_slice.step, int)
res.append(result_slice)
return res
def _get_widths(iteration_slice):
widths = []
for iter_slice in iteration_slice:
step = iter_slice.step
assert isinstance(step, int), f"Step can only be of type int not of type {type(step)}"
start = iter_slice.start
stop = iter_slice.stop
if step == 1:
if stop - start == 0:
widths.append(1)
else:
widths.append(stop - start)
else:
width = (stop - start) / step
if isinstance(width, int):
widths.append(width)
elif isinstance(width, float):
widths.append(math.ceil(width))
else:
widths.append(div_ceil(stop - start, step))
return widths
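# Worked example: _get_widths([slice(1, 63, 1), slice(0, 64, 2)]) returns [62, 32]
# (62 iterations for the unit-stride slice, ceil((64 - 0) / 2) = 32 for step 2).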
def _loop_ctr_assignments(loop_counter_symbols, coordinates, iteration_space):
loop_ctr_assignments = []
for loop_counter, coordinate, iter_slice in zip(loop_counter_symbols, coordinates, iteration_space):
if isinstance(iter_slice, slice) and iter_slice.step > 1:
offset = (iter_slice.step * iter_slice.start) - iter_slice.start
loop_ctr_assignments.append(SympyAssignment(loop_counter, coordinate * iter_slice.step - offset))
elif iter_slice.start == iter_slice.stop:
loop_ctr_assignments.append(SympyAssignment(loop_counter, 0))
else:
loop_ctr_assignments.append(SympyAssignment(loop_counter, coordinate))
return loop_ctr_assignments
def indexing_creator_from_params(gpu_indexing, gpu_indexing_params):
if isinstance(gpu_indexing, str):
if gpu_indexing == 'block':
indexing_creator = BlockIndexing
elif gpu_indexing == 'line':
indexing_creator = LineIndexing
else:
raise ValueError(f"Unknown GPU indexing {gpu_indexing}. Valid values are 'block' and 'line'")
if gpu_indexing_params:
indexing_creator = partial(indexing_creator, **gpu_indexing_params)
return indexing_creator
else:
return gpu_indexing
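# Usage sketch (illustrative): turn the user-facing configuration into an indexing factory.
#
#     creator = indexing_creator_from_params('block', {'block_size': (64, 4, 1)})
#     indexing = creator(iteration_space=(slice(0, 32, 1),) * 2, data_layout=(0, 1))
#     # 'line' selects LineIndexing; a callable passed as gpu_indexing is returned unchanged.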
import sympy as sp
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.config import CreateKernelConfig
from pystencils.typing import StructType, TypedSymbol
from pystencils.typing.transformations import add_types
from pystencils.field import Field, FieldType
from pystencils.enums import Target, Backend
from pystencils.gpu.gpujit import make_python_function
from pystencils.node_collection import NodeCollection
from pystencils.gpu.indexing import indexing_creator_from_params
from pystencils.slicing import normalize_slice
from pystencils.transformations import (
get_base_buffer_index, get_common_field, get_common_indexed_element, parse_base_pointer_info,
resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
function_name = config.function_name
indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
iteration_slice = config.iteration_slice
ghost_layers = config.ghost_layers
fields_written = assignments.bound_fields
fields_read = assignments.rhs_fields
assignments = assignments.all_assignments
assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written)
read_only_fields = set([f.name for f in fields_read - fields_written])
buffers = set([f for f in all_fields if FieldType.is_buffer(f)])
fields_without_buffers = all_fields - buffers
field_accesses = set()
num_buffer_accesses = 0
indexed_elements = set()
for eq in assignments:
indexed_elements.update(eq.atoms(sp.Indexed))
field_accesses.update(eq.atoms(Field.Access))
field_accesses = {e for e in field_accesses if not e.is_absolute_access}
num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field))
# common shape and field from which to create the iteration space
common_field = get_common_field(fields_without_buffers)
common_shape = common_field.spatial_shape
if iteration_slice is None:
# determine iteration slice from ghost layers
if ghost_layers is None:
# determine required number of ghost layers from field access
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(common_shape)
iteration_slice = []
if isinstance(ghost_layers, int):
for i in range(len(common_shape)):
iteration_slice.append(slice(ghost_layers, -ghost_layers if ghost_layers > 0 else None))
ghost_layers = [(ghost_layers, ghost_layers)] * len(common_shape)
else:
for i in range(len(common_shape)):
iteration_slice.append(slice(ghost_layers[i][0],
-ghost_layers[i][1] if ghost_layers[i][1] > 0 else None))
iteration_space = normalize_slice(iteration_slice, common_shape)
else:
iteration_space = normalize_slice(iteration_slice, common_shape)
iteration_space = tuple([s if isinstance(s, slice) else slice(s, s + 1, 1) for s in iteration_space])
loop_counter_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(len(iteration_space))]
if len(indexed_elements) > 0:
common_indexed_element = get_common_indexed_element(indexed_elements)
index = common_indexed_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space),
data_layout=common_field.layout)
extended_ctrs = [index.pop(), *loop_counter_symbols]
loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs)
else:
indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
loop_counter_assignments = indexing.get_loop_ctr_assignments(loop_counter_symbols)
assignments = loop_counter_assignments + assignments
block = indexing.guard(Block(assignments), common_shape)
unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
ast = KernelFunction(block,
Target.GPU,
Backend.CUDA,
make_python_function,
ghost_layers,
function_name,
assignments=assignments)
ast.global_variables.update(indexing.index_variables)
base_pointer_spec = config.base_pointer_specification
if base_pointer_spec is None:
base_pointer_spec = []
base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
coord_mapping = {f.name: loop_counter_symbols for f in all_fields}
if any(FieldType.is_buffer(f) for f in all_fields):
base_buffer_index = get_base_buffer_index(ast, loop_counter_symbols, iteration_space)
resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)
resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
field_to_fixed_coordinates=coord_mapping)
# add the function which determines #blocks and #threads as additional member to KernelFunction node
# this is used by the jit
# If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity),
# they are defined here
undefined_loop_counters = {LoopOverCoordinate.is_loop_counter_symbol(s): s for s in ast.body.undefined_symbols
if LoopOverCoordinate.is_loop_counter_symbol(s) is not None}
for i, loop_counter in undefined_loop_counters.items():
ast.body.insert_front(SympyAssignment(loop_counter, indexing.coordinates[i]))
ast.indexing = indexing
return ast
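# Typical route to create_cuda_kernel through the public API (sketch; field names,
# shapes and the stencil weights are illustrative):
#
#     import pystencils as ps
#     f, g = ps.fields("f, g: double[2D]")
#     update = ps.Assignment(g[0, 0], 0.25 * (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]))
#     config = ps.CreateKernelConfig(target=ps.Target.GPU, gpu_indexing='block')
#     ast = ps.create_kernel(update, config=config)   # dispatches to create_cuda_kernel
#     kernel = ast.compile()                          # wraps make_python_function above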
def created_indexed_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
index_fields = config.index_fields
function_name = config.function_name
coordinate_names = config.coordinate_names
indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
fields_written = assignments.bound_fields
fields_read = assignments.rhs_fields
all_fields = fields_read.union(fields_written)
read_only_fields = set([f.name for f in fields_read - fields_written])
# extract the index fields based on the name. The original index field might have been modified
index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates. " \
f"Non-index fields are {non_index_fields}, spatial coordinates are " \
f"{spatial_coordinates}"
spatial_coordinates = list(spatial_coordinates)[0]
assignments = assignments.all_assignments
assignments = add_types(assignments, config)
for index_field in index_fields:
index_field.field_type = FieldType.INDEXED
assert FieldType.is_indexed(index_field)
assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
def get_coordinate_symbol_assignment(name):
for ind_f in index_fields:
assert isinstance(ind_f.dtype, StructType), "Index fields have to have a struct data type"
data_type = ind_f.dtype
if data_type.has_element(name):
rhs = ind_f[0](name)
lhs = TypedSymbol(name, data_type.get_element_type(name))
return SympyAssignment(lhs, rhs)
raise ValueError(f"Index {name} not found in any of the passed index fields")
coordinate_symbol_assignments = [get_coordinate_symbol_assignment(n)
for n in coordinate_names[:spatial_coordinates]]
coordinate_typed_symbols = [eq.lhs for eq in coordinate_symbol_assignments]
idx_field = list(index_fields)[0]
iteration_space = normalize_slice(tuple([slice(None, None, None)]) * len(idx_field.spatial_shape),
idx_field.spatial_shape)
indexing = indexing_creator(iteration_space=iteration_space,
data_layout=idx_field.layout)
function_body = Block(coordinate_symbol_assignments + assignments)
function_body = indexing.guard(function_body, get_common_field(index_fields).spatial_shape)
ast = KernelFunction(function_body, Target.GPU, Backend.CUDA, make_python_function,
None, function_name, assignments=assignments)
ast.global_variables.update(indexing.index_variables)
coord_mapping = indexing.coordinates
base_pointer_spec = [['spatialInner0']]
base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
coord_mapping = {f.name: coord_mapping for f in index_fields}
coord_mapping.update({f.name: coordinate_typed_symbols for f in non_index_fields})
resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=coord_mapping,
field_to_base_pointer_info=base_pointer_info)
# add the function which determines #blocks and #threads as additional member to KernelFunction node
# this is used by the jit
ast.indexing = indexing
return ast
import numpy as np
from itertools import product
from pystencils import CreateKernelConfig, create_kernel
from pystencils.gpu import make_python_function
from pystencils import Assignment, Field
from pystencils.enums import Target
from pystencils.slicing import get_periodic_boundary_src_dst_slices, normalize_slice
def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, index_dim_shape=1, dtype=np.float64):
"""Copies a rectangular part of an array to another non-overlapping part"""
f = Field.create_generic("pdfs", len(domain_size), index_dimensions=index_dimensions, dtype=dtype)
normalized_from_slice = normalize_slice(from_slice, f.spatial_shape)
normalized_to_slice = normalize_slice(to_slice, f.spatial_shape)
offset = [s1.start - s2.start for s1, s2 in zip(normalized_from_slice, normalized_to_slice)]
assert offset == [s1.stop - s2.stop for s1, s2 in zip(normalized_from_slice, normalized_to_slice)], \
"Slices have to have same size"
update_eqs = []
if index_dimensions < 2:
index_dim_shape = [index_dim_shape]
for i in product(*[range(d) for d in index_dim_shape]):
eq = Assignment(f(*i), f[tuple(offset)](*i))
update_eqs.append(eq)
config = CreateKernelConfig(target=Target.GPU, iteration_slice=to_slice, skip_independence_check=True)
ast = create_kernel(update_eqs, config=config)
return ast
def get_periodic_boundary_functor(stencil, domain_size, index_dimensions=0, index_dim_shape=1, ghost_layers=1,
thickness=None, dtype=np.float64, target=Target.GPU):
assert target in {Target.GPU}
src_dst_slice_tuples = get_periodic_boundary_src_dst_slices(stencil, ghost_layers, thickness)
kernels = []
for src_slice, dst_slice in src_dst_slice_tuples:
ast = create_copy_kernel(domain_size, src_slice, dst_slice, index_dimensions, index_dim_shape, dtype)
kernels.append(make_python_function(ast))
def functor(pdfs, **_):
for kernel in kernels:
kernel(pdfs=pdfs)
return functor
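# Usage sketch (illustrative; stencil directions, domain size and array shape are assumptions):
#
#     import cupy as cp
#     stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)]
#     periodicity = get_periodic_boundary_functor(stencil, domain_size=(64, 64))
#     pdfs_gpu = cp.zeros((64, 64))
#     periodicity(pdfs_gpu)   # copies each periodic src slice into its dst slice on the GPU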
from os.path import dirname, realpath
def get_pystencils_include_path():
return dirname(realpath(__file__))
/*
Copyright 2010-2011, D. E. Shaw Research. All rights reserved.
Copyright 2019-2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES
#ifdef __AVX__
#include <immintrin.h> // AVX*
#else
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif
#include <cstdint>
#include <array>
#include <map>
#define QUALIFIERS inline
#define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16)
#define TWOPOW32_INV_FLOAT (2.3283064e-10f)
#include "myintrin.h"
typedef std::uint32_t uint32;
typedef std::uint64_t uint64;
template <typename T, std::size_t Alignment>
class AlignedAllocator
{
public:
typedef T value_type;
template <typename U>
struct rebind {
typedef AlignedAllocator<U, Alignment> other;
};
T * allocate(const std::size_t n) const {
if (n == 0) {
return nullptr;
}
void * const p = _mm_malloc(n*sizeof(T), Alignment);
if (p == nullptr) {
throw std::bad_alloc();
}
return static_cast<T *>(p);
}
void deallocate(T * const p, const std::size_t n) const {
_mm_free(p);
}
};
template <typename Key, typename T>
using AlignedMap = std::map<Key, T, std::less<Key>, AlignedAllocator<std::pair<const Key, T>, sizeof(Key)>>;
#if defined(__AES__) || defined(_MSC_VER)
QUALIFIERS __m128i aesni_keygen_assist(__m128i temp1, __m128i temp2) {
__m128i temp3;
temp2 = _mm_shuffle_epi32(temp2 ,0xff);
temp3 = _mm_slli_si128(temp1, 0x4);
temp1 = _mm_xor_si128(temp1, temp3);
temp3 = _mm_slli_si128(temp3, 0x4);
temp1 = _mm_xor_si128(temp1, temp3);
temp3 = _mm_slli_si128(temp3, 0x4);
temp1 = _mm_xor_si128(temp1, temp3);
temp1 = _mm_xor_si128(temp1, temp2);
return temp1;
}
QUALIFIERS std::array<__m128i,11> aesni_keygen(__m128i k) {
std::array<__m128i,11> rk;
__m128i tmp;
rk[0] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x1);
k = aesni_keygen_assist(k, tmp);
rk[1] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x2);
k = aesni_keygen_assist(k, tmp);
rk[2] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x4);
k = aesni_keygen_assist(k, tmp);
rk[3] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x8);
k = aesni_keygen_assist(k, tmp);
rk[4] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x10);
k = aesni_keygen_assist(k, tmp);
rk[5] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x20);
k = aesni_keygen_assist(k, tmp);
rk[6] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x40);
k = aesni_keygen_assist(k, tmp);
rk[7] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x80);
k = aesni_keygen_assist(k, tmp);
rk[8] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x1b);
k = aesni_keygen_assist(k, tmp);
rk[9] = k;
tmp = _mm_aeskeygenassist_si128(k, 0x36);
k = aesni_keygen_assist(k, tmp);
rk[10] = k;
return rk;
}
QUALIFIERS const std::array<__m128i,11> & aesni_roundkeys(const __m128i & k128) {
alignas(16) std::array<uint32,4> a;
_mm_store_si128((__m128i*) a.data(), k128);
static AlignedMap<std::array<uint32,4>, std::array<__m128i,11>> roundkeys;
if(roundkeys.find(a) == roundkeys.end()) {
auto rk = aesni_keygen(k128);
roundkeys[a] = rk;
}
return roundkeys[a];
}
QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k0) {
auto k = aesni_roundkeys(k0);
__m128i x = _mm_xor_si128(k[0], in);
x = _mm_aesenc_si128(x, k[1]);
x = _mm_aesenc_si128(x, k[2]);
x = _mm_aesenc_si128(x, k[3]);
x = _mm_aesenc_si128(x, k[4]);
x = _mm_aesenc_si128(x, k[5]);
x = _mm_aesenc_si128(x, k[6]);
x = _mm_aesenc_si128(x, k[7]);
x = _mm_aesenc_si128(x, k[8]);
x = _mm_aesenc_si128(x, k[9]);
x = _mm_aesenclast_si128(x, k[10]);
return x;
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
double & rnd1, double & rnd2)
{
// pack input and call AES
__m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0);
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
c128 = aesni1xm128i(c128, k128);
// convert 32 to 64 bit and put 0th and 2nd element into x, 1st and 3rd element into y
__m128i x = _mm_and_si128(c128, _mm_set_epi32(0, 0xffffffff, 0, 0xffffffff));
__m128i y = _mm_and_si128(c128, _mm_set_epi32(0xffffffff, 0, 0xffffffff, 0));
y = _mm_srli_si128(y, 4);
// calculate z = x ^ (y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
// store result
alignas(16) double rr[2];
_mm_store_pd(rr, rs);
rnd1 = rr[0];
rnd2 = rr[1];
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float & rnd1, float & rnd2, float & rnd3, float & rnd4)
{
// pack input and call AES
__m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0);
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
c128 = aesni1xm128i(c128, k128);
// convert uint32 to float
__m128 rs = _my_cvtepu32_ps(c128);
// calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rs = _mm_fmadd_ps(rs, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#else
rs = _mm_mul_ps(rs, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rs = _mm_add_ps(rs, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
// store result
alignas(16) float r[4];
_mm_store_ps(r, rs);
rnd1 = r[0];
rnd2 = r[1];
rnd3 = r[2];
rnd4 = r[3];
}
template<bool high>
QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
}
else
{
x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
}
// calculate z = x ^ (y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
// pack input and call AES
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k128);
}
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
// convert uint32 to float
rnd1 = _my_cvtepu32_ps(ctr[0]);
rnd2 = _my_cvtepu32_ps(ctr[1]);
rnd3 = _my_cvtepu32_ps(ctr[2]);
rnd4 = _my_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void aesni_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
// pack input and call AES
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k128);
}
_MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void aesni_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m128d & rnd1, __m128d & rnd2)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
__m128d ignore;
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore);
}
#endif
#ifdef __AVX2__
QUALIFIERS const std::array<__m256i,11> & aesni_roundkeys(const __m256i & k256) {
alignas(32) std::array<uint32,8> a;
_mm256_store_si256((__m256i*) a.data(), k256);
static AlignedMap<std::array<uint32,8>, std::array<__m256i,11>> roundkeys;
if(roundkeys.find(a) == roundkeys.end()) {
auto rk1 = aesni_keygen(_mm256_extractf128_si256(k256, 0));
auto rk2 = aesni_keygen(_mm256_extractf128_si256(k256, 1));
for(int i = 0; i < 11; ++i) {
roundkeys[a][i] = _my256_set_m128i(rk2[i], rk1[i]);
}
}
return roundkeys[a];
}
QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k0) {
#if defined(__VAES__)
auto k = aesni_roundkeys(k0);
__m256i x = _mm256_xor_si256(k[0], in);
x = _mm256_aesenc_epi128(x, k[1]);
x = _mm256_aesenc_epi128(x, k[2]);
x = _mm256_aesenc_epi128(x, k[3]);
x = _mm256_aesenc_epi128(x, k[4]);
x = _mm256_aesenc_epi128(x, k[5]);
x = _mm256_aesenc_epi128(x, k[6]);
x = _mm256_aesenc_epi128(x, k[7]);
x = _mm256_aesenc_epi128(x, k[8]);
x = _mm256_aesenc_epi128(x, k[9]);
x = _mm256_aesenclast_epi128(x, k[10]);
#else
__m128i a = aesni1xm128i(_mm256_extractf128_si256(in, 0), _mm256_extractf128_si256(k0, 0));
__m128i b = aesni1xm128i(_mm256_extractf128_si256(in, 1), _mm256_extractf128_si256(k0, 1));
__m256i x = _my256_set_m128i(b, a);
#endif
return x;
}
template<bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 1));
y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 1));
}
else
{
x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 0));
y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 0));
}
// calculate z = x ^ (y << (53 - 32))
__m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm256_xor_si256(x, z);
// convert uint64 to double
__m256d rs = _my256_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
// pack input and call AES
__m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k256);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
// convert uint32 to float
rnd1 = _my256_cvtepu32_ps(ctr[0]);
rnd2 = _my256_cvtepu32_ps(ctr[1]);
rnd3 = _my256_cvtepu32_ps(ctr[2]);
rnd4 = _my256_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
// pack input and call AES
__m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k256);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm256_extractf128_si256(ctr[i], 0);
b[i] = _mm256_extractf128_si256(ctr[i], 1);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my256_set_m128i(b[i], a[i]);
}
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void aesni_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m256d & rnd1, __m256d & rnd2)
{
#if 0
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
__m256d ignore;
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore);
#else
__m128d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
aesni_double2(ctr0, _mm256_extractf128_si256(ctr1, 0), ctr2, ctr3, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
rnd1 = _my256_set_m128d(rnd1hi, rnd1lo);
rnd2 = _my256_set_m128d(rnd2hi, rnd2lo);
#endif
}
#endif
#if defined(__AVX512F__) || defined(__AVX10_512BIT__)
QUALIFIERS const std::array<__m512i,11> & aesni_roundkeys(const __m512i & k512) {
alignas(64) std::array<uint32,16> a;
_mm512_store_si512((__m512i*) a.data(), k512);
static AlignedMap<std::array<uint32,16>, std::array<__m512i,11>> roundkeys;
if(roundkeys.find(a) == roundkeys.end()) {
auto rk1 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 0));
auto rk2 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 1));
auto rk3 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 2));
auto rk4 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 3));
for(int i = 0; i < 11; ++i) {
roundkeys[a][i] = _my512_set_m128i(rk4[i], rk3[i], rk2[i], rk1[i]);
}
}
return roundkeys[a];
}
QUALIFIERS __m512i aesni1xm128i(const __m512i & in, const __m512i & k0) {
#ifdef __VAES__
auto k = aesni_roundkeys(k0);
__m512i x = _mm512_xor_si512(k[0], in);
x = _mm512_aesenc_epi128(x, k[1]);
x = _mm512_aesenc_epi128(x, k[2]);
x = _mm512_aesenc_epi128(x, k[3]);
x = _mm512_aesenc_epi128(x, k[4]);
x = _mm512_aesenc_epi128(x, k[5]);
x = _mm512_aesenc_epi128(x, k[6]);
x = _mm512_aesenc_epi128(x, k[7]);
x = _mm512_aesenc_epi128(x, k[8]);
x = _mm512_aesenc_epi128(x, k[9]);
x = _mm512_aesenclast_epi128(x, k[10]);
#else
__m128i a = aesni1xm128i(_mm512_extracti32x4_epi32(in, 0), _mm512_extracti32x4_epi32(k0, 0));
__m128i b = aesni1xm128i(_mm512_extracti32x4_epi32(in, 1), _mm512_extracti32x4_epi32(k0, 1));
__m128i c = aesni1xm128i(_mm512_extracti32x4_epi32(in, 2), _mm512_extracti32x4_epi32(k0, 2));
__m128i d = aesni1xm128i(_mm512_extracti32x4_epi32(in, 3), _mm512_extracti32x4_epi32(k0, 3));
__m512i x = _my512_set_m128i(d, c, b, a);
#endif
return x;
}
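// AVX-512 variant of _uniform_double_hq: widens eight of the sixteen 32-bit lanes (low or high
// half, selected by the template parameter) to 64 bit, combines x with the shifted y into 53
// random bits per lane, and maps them to uniform doubles in (0,1).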
template<bool high>
QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 1));
y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 1));
}
else
{
x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 0));
y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 0));
}
// calculate z = x ^ (y << (53 - 32))
__m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm512_xor_si512(x, z);
// convert uint64 to double
__m512d rs = _mm512_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
// pack input and call AES
__m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
key3, key2, key1, key0, key3, key2, key1, key0);
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4], c[4], d[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k512);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
// convert uint32 to float
rnd1 = _mm512_cvtepu32_ps(ctr[0]);
rnd2 = _mm512_cvtepu32_ps(ctr[1]);
rnd3 = _mm512_cvtepu32_ps(ctr[2]);
rnd4 = _mm512_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
}
QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
// pack input and call AES
__m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
key3, key2, key1, key0, key3, key2, key1, key0);
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
__m128i a[4], b[4], c[4], d[4];
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k512);
}
for (int i = 0; i < 4; ++i)
{
a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
}
_MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
_MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
_MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
}
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void aesni_float4(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void aesni_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
__m512d & rnd1, __m512d & rnd2)
{
#if 0
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
__m512d ignore;
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore);
#else
__m256d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
aesni_double2(ctr0, _mm512_extracti64x4_epi64(ctr1, 0), ctr2, ctr3, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
rnd1 = _my512_set_m256d(rnd1hi, rnd1lo);
rnd2 = _my512_set_m256d(rnd2hi, rnd2lo);
#endif
}
#endif
/*
Copyright 2021-2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#if defined(_MSC_VER) && defined(_M_ARM64)
#define __ARM_NEON
#endif
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
#include <arm_sve.h>
typedef svbool_t svbool_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#endif
#ifdef __ARM_NEON
inline float32x4_t makeVec_f32(float a, float b, float c, float d)
{
alignas(16) float data[4] = {a, b, c, d};
return vld1q_f32(data);
}
inline float64x2_t makeVec_f64(double a, double b)
{
alignas(16) double data[2] = {a, b};
return vld1q_f64(data);
}
inline int32x4_t makeVec_s32(int a, int b, int c, int d)
{
alignas(16) int data[4] = {a, b, c, d};
return vld1q_s32(data);
}
#endif
inline void cachelineZero(void * p) {
#if !defined(_MSC_VER) || defined(__clang__)
__asm__ volatile("dc zva, %0"::"r"(p):"memory");
#endif
}
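// Probes the cache-line size used by DC ZVA: fills a buffer with ones, zeroes one cache line
// inside it, and measures how many bytes actually became zero. Returns SIZE_MAX if DC ZVA is
// prohibited or the measurement is implausible.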
inline size_t _cachelineSize() {
#if !defined(_MSC_VER) || defined(__clang__)
// check that dc zva is permitted
uint64_t dczid;
__asm__ volatile ("mrs %0, dczid_el0" : "=r"(dczid));
if ((dczid & (1 << 4)) != 0) {
return SIZE_MAX;
}
// allocate and fill with ones
const size_t max_size = 0x100000;
uint8_t data[2*max_size];
for (size_t i = 0; i < 2*max_size; ++i) {
data[i] = 0xff;
}
// find alignment offset
size_t offset = max_size - ((uintptr_t) data) % max_size;
// zero a cacheline
cachelineZero((void*) (data + offset));
// make sure that at least one byte was zeroed
if (data[offset] != 0) {
return SIZE_MAX;
}
// make sure that nothing was zeroed before the pointer
if (data[offset-1] == 0) {
return SIZE_MAX;
}
// find the last byte that was zeroed
for (size_t size = 1; size < max_size; ++size) {
if (data[offset + size] != 0) {
return size;
}
}
#endif
// too much was zeroed
return SIZE_MAX;
}
inline size_t cachelineSize() {
static size_t size = _cachelineSize();
return size;
}
/*
Copyright 2023, Markus Holzer.
Copyright 2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#define POS_INFINITY __int_as_float(0x7f800000)
#define INFINITY POS_INFINITY
#define NEG_INFINITY __int_as_float(0xff800000)
#ifdef __HIPCC_RTC__
typedef __hip_uint8_t uint8_t;
typedef __hip_int8_t int8_t;
typedef __hip_uint16_t uint16_t;
typedef __hip_int16_t int16_t;
#endif
/*
Copyright 2023, Markus Holzer.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/// Half precision support. Experimental. Use carefully.
///
/// This feature is experimental, since it strictly depends on the underlying architecture and compiler support.
/// On x86 architectures, what you can expect is that the data format is supported natively only for storage and
/// interchange. Arithmetic operations will likely involve casting to fp32 (C++ float) and truncation to fp16.
/// Only bandwidth bound code may therefore benefit. None of this is guaranteed, and may change in the future.
/// Clang version must be 15 or higher for x86 half precision support.
/// GCC version must be 12 or higher for x86 half precision support.
/// Also, support seems to require SSE, so ensure that the respective instruction sets are enabled.
/// See
/// https://clang.llvm.org/docs/LanguageExtensions.html#half-precision-floating-point
/// https://gcc.gnu.org/onlinedocs/gcc/Half-Precision.html
/// for more information.
#pragma once
using half = _Float16;
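// A minimal usage sketch (illustrative only, based on the notes above; not required by this
// header): on x86, arithmetic on half is expected to happen in fp32 and to be truncated back to
// fp16 on assignment, e.g.
//   half a = 1.5f, b = 0.25f;
//   float s = static_cast<float>(a) * static_cast<float>(b);  // compute in fp32
//   half  r = static_cast<half>(s);                           // truncate to fp16 for storage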
/*
Copyright 2019-2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#if defined(__SSE2__) || (defined(_MSC_VER) && !defined(_M_ARM64))
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{
#if defined(__AVX512VL__) || defined(__AVX10_1__)
return _mm_cvtepu32_ps(v);
#else
__m128i v2 = _mm_srli_epi32(v, 1);
__m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1));
__m128 v2f = _mm_cvtepi32_ps(v2);
__m128 v1f = _mm_cvtepi32_ps(v1);
return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f);
#endif
}
QUALIFIERS void _MY_TRANSPOSE4_EPI32(__m128i & R0, __m128i & R1, __m128i & R2, __m128i & R3)
{
__m128i T0, T1, T2, T3;
T0 = _mm_unpacklo_epi32(R0, R1);
T1 = _mm_unpacklo_epi32(R2, R3);
T2 = _mm_unpackhi_epi32(R0, R1);
T3 = _mm_unpackhi_epi32(R2, R3);
R0 = _mm_unpacklo_epi64(T0, T1);
R1 = _mm_unpackhi_epi64(T0, T1);
R2 = _mm_unpacklo_epi64(T2, T3);
R3 = _mm_unpackhi_epi64(T2, T3);
}
#endif
#if defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#if !defined(__AVX512VL__) && !defined(__AVX10_1__) && defined(__GNUC__) && __GNUC__ >= 5 && !defined(__clang__)
__attribute__((optimize("no-associative-math")))
#endif
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{
#if defined(__AVX512VL__) || defined(__AVX10_1__)
return _mm_cvtepu64_pd(x);
#elif defined(__clang__)
return __builtin_convertvector((uint64_t __attribute__((__vector_size__(16)))) x, __m128d);
#else
__m128i xH = _mm_srli_epi64(x, 32);
xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84
__m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc); // 2^52
__m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52
return _mm_add_pd(f, _mm_castsi128_pd(xL));
#endif
}
#endif
#ifdef __AVX2__
QUALIFIERS __m256i _my256_set_m128i(__m128i hi, __m128i lo)
{
#if (!defined(__GNUC__) || __GNUC__ >= 8) || defined(__clang__)
return _mm256_set_m128i(hi, lo);
#else
return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
#endif
}
QUALIFIERS __m256d _my256_set_m128d(__m128d hi, __m128d lo)
{
#if (!defined(__GNUC__) || __GNUC__ >= 8) || defined(__clang__)
return _mm256_set_m128d(hi, lo);
#else
return _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 1);
#endif
}
QUALIFIERS __m256 _my256_cvtepu32_ps(const __m256i v)
{
#if defined(__AVX512VL__) || defined(__AVX10_1__)
return _mm256_cvtepu32_ps(v);
#else
__m256i v2 = _mm256_srli_epi32(v, 1);
__m256i v1 = _mm256_and_si256(v, _mm256_set1_epi32(1));
__m256 v2f = _mm256_cvtepi32_ps(v2);
__m256 v1f = _mm256_cvtepi32_ps(v1);
return _mm256_add_ps(_mm256_add_ps(v2f, v2f), v1f);
#endif
}
#if !defined(__AVX512VL__) && !defined(__AVX10_1__) && defined(__GNUC__) && __GNUC__ >= 5 && !defined(__clang__)
__attribute__((optimize("no-associative-math")))
#endif
QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x)
{
#if defined(__AVX512VL__) || defined(__AVX10_1__)
return _mm256_cvtepu64_pd(x);
#elif defined(__clang__)
return __builtin_convertvector((uint64_t __attribute__((__vector_size__(32)))) x, __m256d);
#else
__m256i xH = _mm256_srli_epi64(x, 32);
xH = _mm256_or_si256(xH, _mm256_castpd_si256(_mm256_set1_pd(19342813113834066795298816.))); // 2^84
__m256i xL = _mm256_blend_epi16(x, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0xcc); // 2^52
__m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52
return _mm256_add_pd(f, _mm256_castsi256_pd(xL));
#endif
}
#endif
#if defined(__AVX512F__) || defined(__AVX10_512BIT__)
QUALIFIERS __m512i _my512_set_m128i(__m128i d, __m128i c, __m128i b, __m128i a)
{
return _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(a), b, 1), c, 2), d, 3);
}
QUALIFIERS __m512d _my512_set_m256d(__m256d b, __m256d a)
{
return _mm512_insertf64x4(_mm512_castpd256_pd512(a), b, 1);
}
#endif
/*
Copyright 2010-2011, D. E. Shaw Research. All rights reserved.
Copyright 2019-2024, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#if !defined(__OPENCL_VERSION__) && !defined(__HIPCC_RTC__)
#if defined(__SSE2__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#include <emmintrin.h> // SSE2
#endif
#ifdef __AVX2__
#include <immintrin.h> // AVX*
#elif defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif
#if defined(_MSC_VER) && defined(_M_ARM64)
#define __ARM_NEON
#endif
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
#include <arm_sve.h>
#endif
#if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) && !defined(__xlC__)
#include <ppu_intrinsics.h>
#endif
#ifdef __ALTIVEC__
#include <altivec.h>
#undef bool
#ifndef _ARCH_PWR8
#include <pveclib/vec_int64_ppc.h>
#endif
#endif
#ifdef __riscv_v
#include <riscv_vector.h>
#endif
#endif
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
#define QUALIFIERS static __forceinline__ __device__
#elif defined(__OPENCL_VERSION__)
#define QUALIFIERS static inline
#else
#define QUALIFIERS inline
#include "myintrin.h"
#endif
#if defined(__ARM_FEATURE_SME)
#define SVE_QUALIFIERS __attribute__((arm_streaming_compatible)) QUALIFIERS
#else
#define SVE_QUALIFIERS QUALIFIERS
#endif
#define PHILOX_W32_0 (0x9E3779B9)
#define PHILOX_W32_1 (0xBB67AE85)
#define PHILOX_M4x32_0 (0xD2511F53)
#define PHILOX_M4x32_1 (0xCD9E8D57)
#define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16)
#define TWOPOW32_INV_FLOAT (2.3283064e-10f)
#ifdef __OPENCL_VERSION__
#include "opencl_stdint.h"
typedef uint32_t uint32;
typedef uint64_t uint64;
#else
#ifndef __HIPCC_RTC__
#include <cstdint>
#endif
typedef std::uint32_t uint32;
typedef std::uint64_t uint64;
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#elif defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
typedef svfloat32_t svfloat32_st;
typedef svfloat64_t svfloat64_st;
#endif
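// mulhilo32: full 32x32 -> 64-bit multiply; returns the low 32 bits of a*b and stores the high
// 32 bits in *hip. This is the core operation of each Philox round.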
QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip)
{
#if !defined(__CUDA_ARCH__) && !defined(__HIP_DEVICE_COMPILE__)
// host code
#if defined(__powerpc__) && (!defined(__clang__) || defined(__xlC__))
*hip = __mulhwu(a,b);
return a*b;
#elif defined(__OPENCL_VERSION__)
*hip = mul_hi(a,b);
return a*b;
#else
uint64 product = ((uint64)a) * ((uint64)b);
*hip = product >> 32;
return (uint32)product;
#endif
#else
// device code
*hip = __umulhi(a,b);
return a*b;
#endif
}
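// One Philox-4x32 round: two 32x32 -> 64-bit multiplies of counter words with the round
// constants, whose high and low halves are then XORed and permuted with the key words.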
QUALIFIERS void _philox4x32round(uint32* ctr, uint32* key)
{
uint32 hi0;
uint32 hi1;
uint32 lo0 = mulhilo32(PHILOX_M4x32_0, ctr[0], &hi0);
uint32 lo1 = mulhilo32(PHILOX_M4x32_1, ctr[2], &hi1);
ctr[0] = hi1^ctr[1]^key[0];
ctr[1] = lo1;
ctr[2] = hi0^ctr[3]^key[1];
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(uint32* key)
{
key[0] += PHILOX_W32_0;
key[1] += PHILOX_W32_1;
}
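// Combines two 32-bit counter words into 53 random bits and maps them to a uniform double in
// (0,1); the added TWOPOW53_INV_DOUBLE/2 keeps the result strictly positive.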
QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y)
{
uint64 z = (uint64)x ^ ((uint64)y << (53 - 32));
return z * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0);
}
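// Scalar reference implementation: runs the full 10-round Philox-4x32-10 block cipher on the
// counter and converts the four 32-bit outputs into two uniform doubles.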
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
double * rnd1, double * rnd2)
#else
double & rnd1, double & rnd2)
#endif
{
uint32 key[2] = {key0, key1};
uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
#ifdef __OPENCL_VERSION__
*rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
*rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#else
rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#endif
}
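// Scalar reference implementation: 10 Philox rounds followed by conversion of each 32-bit
// output to a uniform float in (0,1).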
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
float * rnd1, float * rnd2, float * rnd3, float * rnd4)
#else
float & rnd1, float & rnd2, float & rnd3, float & rnd4)
#endif
{
uint32 key[2] = {key0, key1};
uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
#ifdef __OPENCL_VERSION__
*rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
#else
rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
#endif
}
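// A minimal usage sketch (illustrative only; the exact counter/key convention is up to the
// caller): one common choice is to build the counter from a time step and the loop coordinates
// and the key from a user seed, e.g.
//   float r1, r2, r3, r4;
//   philox_float4(time_step, x, y, z, seed, 0u, r1, r2, r3, r4);
// so that every cell and every time step draws independent but reproducible numbers.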
#if !defined(__CUDA_ARCH__) && !defined(__OPENCL_VERSION__) && !defined(__HIP_DEVICE_COMPILE__)
#if defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
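// SSE4.1 specialization of the Philox round: _mm_mul_epu32 on the even and odd 32-bit lanes
// yields the 64-bit products, which are shuffled back into separate low- and high-half vectors
// before the usual XOR/permute with the key.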
QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key)
{
__m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi0b = _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1));
__m128i lohi1b = _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8);
__m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b);
__m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b);
__m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b);
__m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m128i* key)
{
key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0));
key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
}
else
{
x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
}
// calculate z = x ^ (y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void philox_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my_cvtepu32_ps(ctr[0]);
rnd2 = _my_cvtepu32_ps(ctr[1]);
rnd3 = _my_cvtepu32_ps(ctr[2]);
rnd4 = _my_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1, __m128d & rnd2)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
__m128d ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
#endif
#ifdef __ALTIVEC__
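// POWER AltiVec/VSX specialization of the same 10-round scheme; the high halves of the
// 32x32-bit products come from vec_mulh/vec_mulhuw where available and from even/odd multiplies
// plus merges otherwise.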
QUALIFIERS void _philox4x32round(__vector unsigned int* ctr, __vector unsigned int* key)
{
#ifndef _ARCH_PWR8
__vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int hi0 = vec_mulhuw(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi1 = vec_mulhuw(ctr[2], vec_splats(PHILOX_M4x32_1));
#elif defined(_ARCH_PWR10)
__vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int hi0 = vec_mulh(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi1 = vec_mulh(ctr[2], vec_splats(PHILOX_M4x32_1));
#else
__vector unsigned int lohi0a = (__vector unsigned int) vec_mule(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lohi0b = (__vector unsigned int) vec_mulo(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lohi1a = (__vector unsigned int) vec_mule(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int lohi1b = (__vector unsigned int) vec_mulo(ctr[2], vec_splats(PHILOX_M4x32_1));
#ifdef __LITTLE_ENDIAN__
__vector unsigned int lo0 = vec_mergee(lohi0a, lohi0b);
__vector unsigned int lo1 = vec_mergee(lohi1a, lohi1b);
__vector unsigned int hi0 = vec_mergeo(lohi0a, lohi0b);
__vector unsigned int hi1 = vec_mergeo(lohi1a, lohi1b);
#else
__vector unsigned int lo0 = vec_mergeo(lohi0a, lohi0b);
__vector unsigned int lo1 = vec_mergeo(lohi1a, lohi1b);
__vector unsigned int hi0 = vec_mergee(lohi0a, lohi0b);
__vector unsigned int hi1 = vec_mergee(lohi1a, lohi1b);
#endif
#endif
ctr[0] = vec_xor(vec_xor(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = vec_xor(vec_xor(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__vector unsigned int* key)
{
key[0] = vec_add(key[0], vec_splats(PHILOX_W32_0));
key[1] = vec_add(key[1], vec_splats(PHILOX_W32_1));
}
#ifdef __VSX__
template<bool high>
QUALIFIERS __vector double _uniform_double_hq(__vector unsigned int x, __vector unsigned int y)
{
// convert 32 to 64 bit
#ifdef __LITTLE_ENDIAN__
if (high)
{
x = vec_mergel(x, vec_splats(0U));
y = vec_mergel(y, vec_splats(0U));
}
else
{
x = vec_mergeh(x, vec_splats(0U));
y = vec_mergeh(y, vec_splats(0U));
}
#else
if (high)
{
x = vec_mergel(vec_splats(0U), x);
y = vec_mergel(vec_splats(0U), y);
}
else
{
x = vec_mergeh(vec_splats(0U), x);
y = vec_mergeh(vec_splats(0U), y);
}
#endif
// calculate z = x ^ (y << (53 - 32))
#ifdef _ARCH_PWR8
__vector unsigned long long z = vec_sl((__vector unsigned long long) y, vec_splats(53ULL - 32ULL));
#else
__vector unsigned long long z = vec_vsld((__vector unsigned long long) y, vec_splats(53ULL - 32ULL));
#endif
z = vec_xor((__vector unsigned long long) x, z);
// convert uint64 to double
#ifdef __xlC__
__vector double rs = vec_ctd(z, 0);
#else
__vector double rs = vec_ctf(z, 0);
#endif
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vec_madd(rs, vec_splats(TWOPOW53_INV_DOUBLE), vec_splats(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
#endif
QUALIFIERS void philox_float4(__vector unsigned int ctr0, __vector unsigned int ctr1, __vector unsigned int ctr2, __vector unsigned int ctr3,
uint32 key0, uint32 key1,
__vector float & rnd1, __vector float & rnd2, __vector float & rnd3, __vector float & rnd4)
{
__vector unsigned int key[2] = {vec_splats(key0), vec_splats(key1)};
__vector unsigned int ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = vec_ctf(ctr[0], 0);
rnd2 = vec_ctf(ctr[1], 0);
rnd3 = vec_ctf(ctr[2], 0);
rnd4 = vec_ctf(ctr[3], 0);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vec_madd(rnd1, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = vec_madd(rnd2, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = vec_madd(rnd3, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = vec_madd(rnd4, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
}
#ifdef __VSX__
QUALIFIERS void philox_double2(__vector unsigned int ctr0, __vector unsigned int ctr1, __vector unsigned int ctr2, __vector unsigned int ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1lo, __vector double & rnd1hi, __vector double & rnd2lo, __vector double & rnd2hi)
{
__vector unsigned int key[2] = {vec_splats(key0), vec_splats(key1)};
__vector unsigned int ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
#endif
QUALIFIERS void philox_float4(uint32 ctr0, __vector unsigned int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector float & rnd1, __vector float & rnd2, __vector float & rnd3, __vector float & rnd4)
{
__vector unsigned int ctr0v = vec_splats(ctr0);
__vector unsigned int ctr2v = vec_splats(ctr2);
__vector unsigned int ctr3v = vec_splats(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_float4(uint32 ctr0, __vector int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector float & rnd1, __vector float & rnd2, __vector float & rnd3, __vector float & rnd4)
{
philox_float4(ctr0, (__vector unsigned int) ctr1, ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
#ifdef __VSX__
QUALIFIERS void philox_double2(uint32 ctr0, __vector unsigned int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1lo, __vector double & rnd1hi, __vector double & rnd2lo, __vector double & rnd2hi)
{
__vector unsigned int ctr0v = vec_splats(ctr0);
__vector unsigned int ctr2v = vec_splats(ctr2);
__vector unsigned int ctr3v = vec_splats(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __vector unsigned int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1, __vector double & rnd2)
{
__vector unsigned int ctr0v = vec_splats(ctr0);
__vector unsigned int ctr2v = vec_splats(ctr2);
__vector unsigned int ctr3v = vec_splats(ctr3);
__vector double ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, __vector int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1, __vector double & rnd2)
{
philox_double2(ctr0, (__vector unsigned int) ctr1, ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#endif
#if defined(__ARM_NEON)
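// AArch64 NEON specialization: vmull_u32/vmull_high_u32 produce the 64-bit products and
// vuzp1q/vuzp2q separate their low and high 32-bit halves.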
QUALIFIERS void _philox4x32round(uint32x4_t* ctr, uint32x4_t* key)
{
uint32x4_t lohi0a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[0]), vdup_n_u32(PHILOX_M4x32_0)));
uint32x4_t lohi0b = vreinterpretq_u32_u64(vmull_high_u32(ctr[0], vdupq_n_u32(PHILOX_M4x32_0)));
uint32x4_t lohi1a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[2]), vdup_n_u32(PHILOX_M4x32_1)));
uint32x4_t lohi1b = vreinterpretq_u32_u64(vmull_high_u32(ctr[2], vdupq_n_u32(PHILOX_M4x32_1)));
uint32x4_t lo0 = vuzp1q_u32(lohi0a, lohi0b);
uint32x4_t lo1 = vuzp1q_u32(lohi1a, lohi1b);
uint32x4_t hi0 = vuzp2q_u32(lohi0a, lohi0b);
uint32x4_t hi1 = vuzp2q_u32(lohi1a, lohi1b);
ctr[0] = veorq_u32(veorq_u32(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = veorq_u32(veorq_u32(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(uint32x4_t* key)
{
key[0] = vaddq_u32(key[0], vdupq_n_u32(PHILOX_W32_0));
key[1] = vaddq_u32(key[1], vdupq_n_u32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS float64x2_t _uniform_double_hq(uint32x4_t x, uint32x4_t y)
{
// convert 32 to 64 bit
if (high)
{
x = vzip2q_u32(x, vdupq_n_u32(0));
y = vzip2q_u32(y, vdupq_n_u32(0));
}
else
{
x = vzip1q_u32(x, vdupq_n_u32(0));
y = vzip1q_u32(y, vdupq_n_u32(0));
}
// calculate z = x ^ (y << (53 - 32))
uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
z = veorq_u64(vreinterpretq_u64_u32(x), z);
// convert uint64 to double
float64x2_t rs = vcvtq_f64_u64(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE/2.0), vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);
return rs;
}
QUALIFIERS void philox_float4(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = vcvtq_f32_u32(ctr[0]);
rnd2 = vcvtq_f32_u32(ctr[1]);
rnd3 = vcvtq_f32_u32(ctr[2]);
rnd4 = vcvtq_f32_u32(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd1);
rnd2 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd2);
rnd3 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd3);
rnd4 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd4);
}
QUALIFIERS void philox_double2(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
#ifndef _MSC_VER
QUALIFIERS void philox_float4(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
philox_float4(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
#endif
QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1, float64x2_t & rnd2)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
float64x2_t ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
#ifndef _MSC_VER
QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1, float64x2_t & rnd2)
{
philox_double2(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#endif
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
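// ARM SVE/SME specialization on scalable vectors; svmulh_u32_x supplies the high half of each
// 32x32-bit product directly.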
SVE_QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
{
svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0)));
ctr = svset4_u32(ctr, 1, lo1);
ctr = svset4_u32(ctr, 2, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, svget4_u32(ctr, 3)), svget2_u32(key, 1)));
ctr = svset4_u32(ctr, 3, lo0);
}
SVE_QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key)
{
key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0)));
key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1)));
}
template<bool high>
SVE_QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y)
{
// convert 32 to 64 bit
if (high)
{
x = svzip2_u32(x, svdup_u32(0));
y = svzip2_u32(y, svdup_u32(0));
}
else
{
x = svzip1_u32(x, svdup_u32(0));
y = svzip1_u32(y, svdup_u32(0));
}
// calculate z = x ^ (y << (53 - 32))
svuint64_t z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z);
// convert uint64 to double
svfloat64_t rs = svcvt_f64_u64_x(svptrue_b64(), z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
SVE_QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 0));
rnd2 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 1));
rnd3 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 2));
rnd4 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 3));
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd3 = svmad_f32_x(svptrue_b32(), rnd3, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd4 = svmad_f32_x(svptrue_b32(), rnd4, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
}
SVE_QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd1hi = _uniform_double_hq<true>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd2lo = _uniform_double_hq<false>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
}
SVE_QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
SVE_QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
svfloat64_st ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
SVE_QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#if defined(__riscv_v)
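// RISC-V Vector specialization; vmulhu supplies the high product halves, and vslidedown plus
// the widening conversions select the 32-bit lanes for the double path.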
QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32m1_t & ctr2, vuint32m1_t & ctr3,
vuint32m1_t key0, vuint32m1_t key1)
{
vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1());
ctr1 = lo1;
ctr2 = vxor_vv_u32m1(vxor_vv_u32m1(hi0, ctr3, vsetvlmax_e32m1()), key1, vsetvlmax_e32m1());
ctr3 = lo0;
}
QUALIFIERS void _philox4x32bumpkey(vuint32m1_t & key0, vuint32m1_t & key1)
{
key0 = vadd_vv_u32m1(key0, vmv_v_x_u32m1(PHILOX_W32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
key1 = vadd_vv_u32m1(key1, vmv_v_x_u32m1(PHILOX_W32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
}
template<bool high>
QUALIFIERS vfloat64m1_t _uniform_double_hq(vuint32m1_t x, vuint32m1_t y)
{
// convert 32 to 64 bit
if (high)
{
size_t s = vsetvlmax_e32m1();
x = vslidedown_vx_u32m1(vundefined_u32m1(), x, s/2, s);
y = vslidedown_vx_u32m1(vundefined_u32m1(), y, s/2, s);
}
vuint64m1_t x64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(x), vsetvlmax_e64m1());
vuint64m1_t y64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(y), vsetvlmax_e64m1());
// calculate z = x ^ (y << (53 - 32))
vuint64m1_t z = vsll_vx_u64m1(y64, 53 - 32, vsetvlmax_e64m1());
z = vxor_vv_u64m1(x64, z, vsetvlmax_e64m1());
// convert uint64 to double
vfloat64m1_t rs = vfcvt_f_xu_v_f64m1(z, vsetvlmax_e64m1());
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmadd_vv_f64m1(rs, vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE, vsetvlmax_e64m1()), vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE/2.0, vsetvlmax_e64m1()), vsetvlmax_e64m1());
return rs;
}
QUALIFIERS void philox_float4(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1());
vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1());
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10
// convert uint32 to float
rnd1 = vfcvt_f_xu_v_f32m1(ctr0, vsetvlmax_e32m1());
rnd2 = vfcvt_f_xu_v_f32m1(ctr1, vsetvlmax_e32m1());
rnd3 = vfcvt_f_xu_v_f32m1(ctr2, vsetvlmax_e32m1());
rnd4 = vfcvt_f_xu_v_f32m1(ctr3, vsetvlmax_e32m1());
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmadd_vv_f32m1(rnd1, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd2 = vfmadd_vv_f32m1(rnd2, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd3 = vfmadd_vv_f32m1(rnd3, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd4 = vfmadd_vv_f32m1(rnd4, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
}
QUALIFIERS void philox_double2(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi)
{
vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1());
vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1());
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10
rnd1lo = _uniform_double_hq<false>(ctr0, ctr1);
rnd1hi = _uniform_double_hq<true>(ctr0, ctr1);
rnd2lo = _uniform_double_hq<false>(ctr2, ctr3);
rnd2hi = _uniform_double_hq<true>(ctr2, ctr3);
}
QUALIFIERS void philox_float4(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_float4(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
philox_float4(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1, vfloat64m1_t & rnd2)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
vfloat64m1_t ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1, vfloat64m1_t & rnd2)
{
philox_double2(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
{
__m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32), _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1));
__m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32), _mm256_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8);
__m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b);
__m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b);
__m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b);
__m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m256i* key)
{
key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0));
key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 1));
y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 1));
}
else
{
x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 0));
y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 0));
}
// calculate z = x ^ (y << (53 - 32))
__m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm256_xor_si256(x, z);
// convert uint64 to double
__m256d rs = _my256_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void philox_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my256_cvtepu32_ps(ctr[0]);
rnd2 = _my256_cvtepu32_ps(ctr[1]);
rnd3 = _my256_cvtepu32_ps(ctr[2]);
rnd4 = _my256_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1, __m256d & rnd2)
{
#if 0
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
__m256d ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
#else
__m128d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
philox_double2(ctr0, _mm256_extractf128_si256(ctr1, 0), ctr2, ctr3, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
rnd1 = _my256_set_m128d(rnd1hi, rnd1lo);
rnd2 = _my256_set_m128d(rnd2hi, rnd2lo);
#endif
}
#endif
#if defined(__AVX512F__) || defined(__AVX10_512BIT__)
QUALIFIERS void _philox4x32round(__m512i* ctr, __m512i* key)
{
__m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32), _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1));
__m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32), _mm512_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm512_shuffle_epi32(lohi0a, _MM_PERM_DBCA);
lohi0b = _mm512_shuffle_epi32(lohi0b, _MM_PERM_DBCA);
lohi1a = _mm512_shuffle_epi32(lohi1a, _MM_PERM_DBCA);
lohi1b = _mm512_shuffle_epi32(lohi1b, _MM_PERM_DBCA);
__m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b);
__m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b);
__m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b);
__m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m512i* key)
{
key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0));
key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 1));
y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 1));
}
else
{
x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 0));
y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 0));
}
// calculate z = x ^ (y << (53 - 32))
__m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm512_xor_si512(x, z);
// convert uint64 to double
__m512d rs = _mm512_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
QUALIFIERS void philox_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _mm512_cvtepu32_ps(ctr[0]);
rnd2 = _mm512_cvtepu32_ps(ctr[1]);
rnd3 = _mm512_cvtepu32_ps(ctr[2]);
rnd4 = _mm512_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
}
QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1, __m512d & rnd2)
{
#if 0
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
__m512d ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
#else
__m256d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
philox_double2(ctr0, _mm512_extracti64x4_epi64(ctr1, 0), ctr2, ctr3, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
rnd1 = _my512_set_m256d(rnd1hi, rnd1lo);
rnd2 = _my512_set_m256d(rnd2hi, rnd2lo);
#endif
}
#endif
#endif
/*
Copyright 2021, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <altivec.h>
#undef vector
#undef bool
inline void cachelineZero(void * p) {
#ifdef __xlC__
__dcbz(p);
#else
__asm__ volatile("dcbz 0, %0"::"r"(p):"memory");
#endif
}
inline size_t _cachelineSize() {
// allocate and fill with ones
const size_t max_size = 0x100000;
uint8_t data[2*max_size];
for (size_t i = 0; i < 2*max_size; ++i) {
data[i] = 0xff;
}
// find alignment offset
size_t offset = max_size - ((uintptr_t) data) % max_size;
// zero a cacheline
cachelineZero((void*) (data + offset));
// make sure that at least one byte was zeroed
if (data[offset] != 0) {
return SIZE_MAX;
}
// make sure that nothing was zeroed before the pointer
if (data[offset-1] == 0) {
return SIZE_MAX;
}
// find the last byte that was zeroed
for (size_t size = 1; size < max_size; ++size) {
if (data[offset + size] != 0) {
return size;
}
}
// too much was zeroed
return SIZE_MAX;
}
inline size_t cachelineSize() {
static size_t size = _cachelineSize();
return size;
}
/*
Copyright 2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
inline void cachelineZero(void * p) {
#ifdef __riscv_zicboz
__asm__ volatile("cbo.zero (%0)"::"r"(p):"memory");
#endif
}
inline size_t _cachelineSize() {
// allocate and fill with ones
const size_t max_size = 0x100000;
uint8_t data[2*max_size];
for (size_t i = 0; i < 2*max_size; ++i) {
data[i] = 0xff;
}
// find alignment offset
size_t offset = max_size - ((uintptr_t) data) % max_size;
// zero a cacheline
cachelineZero((void*) (data + offset));
// make sure that at least one byte was zeroed
if (data[offset] != 0) {
return SIZE_MAX;
}
// make sure that nothing was zeroed before the pointer
if (data[offset-1] == 0) {
return SIZE_MAX;
}
// find the last byte that was zeroed
for (size_t size = 1; size < max_size; ++size) {
if (data[offset + size] != 0) {
return size;
}
}
// too much was zeroed
return SIZE_MAX;
}
inline size_t cachelineSize() {
#ifdef __riscv_zicboz
static size_t size = _cachelineSize();
return size;
#else
return SIZE_MAX;
#endif
}
# TODO #47 move these to a separate 'functions' module
import numpy as np
import sympy as sp
from pystencils.typing import CastFunc, collate_types, create_type, get_type_of_expression
from pystencils.sympyextensions import is_integer_sequence
class IntegerFunctionTwoArgsMixIn(sp.Function):
is_integer = True
def __new__(cls, arg1, arg2):
args = []
for a in (arg1, arg2):
if isinstance(a, sp.Number) or isinstance(a, int):
args.append(CastFunc(a, create_type("int")))
elif isinstance(a, np.generic):
args.append(CastFunc(a, a.dtype))
else:
args.append(a)
for a in args:
try:
dtype = get_type_of_expression(a)
if not dtype.is_int():
raise ValueError("Argument to integer function is not an int but " + str(dtype))
except NotImplementedError:
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args)
def _eval_evalf(self, *pargs, **kwargs):
arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
return self._eval_op(arg1, arg2)
def _eval_op(self, arg1, arg2):
return self
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_right(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_left(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_and(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_or(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
def _eval_op(self, arg1, arg2):
return int(arg1 // arg2)
# noinspection PyPep8Naming
class int_power_of_2(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class modulo_floor(sp.Function):
"""Returns the next smaller integer divisible by given divisor.
Examples:
>>> modulo_floor(9, 4)
8
>>> modulo_floor(11, 4)
8
>>> modulo_floor(12, 4)
12
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> modulo_floor(a, b).to_c(str)
'(int64_t)((a) / (b)) * (b)'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return (int(integer) // int(divisor)) * divisor
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class modulo_ceil(sp.Function):
"""Returns the next bigger integer divisible by given divisor.
Examples:
>>> modulo_ceil(9, 4)
12
>>> modulo_ceil(11, 4)
12
>>> modulo_ceil(12, 4)
12
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> modulo_ceil(a, b).to_c(str)
'((a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b))'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class div_ceil(sp.Function):
"""Integer division that is always rounded up
Examples:
>>> div_ceil(9, 4)
3
>>> div_ceil(8, 4)
2
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> div_ceil(a, b).to_c(str)
'( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor if integer % divisor == 0 else (integer // divisor) + 1
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class div_floor(sp.Function):
"""Integer division
Examples:
>>> div_floor(9, 4)
2
>>> div_floor(8, 4)
2
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> div_floor(a, b).to_c(str)
'((int64_t)(a) / (int64_t)(b))'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "(({dtype})({0}) / ({dtype})({1}))"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
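# Hedged usage sketch (not part of the original module): constructing the two-argument
# integer functions defined above. TypedSymbol is pystencils' public helper (as used in
# the doctests above); plain Python ints are wrapped in a CastFunc to an int type by
# IntegerFunctionTwoArgsMixIn.__new__, while non-integer typed expressions raise a ValueError.
if __name__ == '__main__':
    from pystencils import TypedSymbol
    a = TypedSymbol("a", "int64")
    shifted = bit_shift_left(a, 2)   # the literal 2 is cast to an int type automatically
    masked = bitwise_and(a, 0xFF)    # same for the mask literal
    print(shifted, masked)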
"""Transformations using integer sets based on ISL library"""
import islpy as isl
import sympy as sp
import pystencils.astnodes as ast
from pystencils.typing import parents_of_type
from pystencils.backends.cbackend import CustomSympyPrinter
def remove_brackets(s):
return s.replace('[', '').replace(']', '')
def _degrees_of_freedom_as_string(expr):
expr = sp.sympify(expr)
indexed = expr.atoms(sp.Indexed)
symbols = expr.atoms(sp.Symbol)
symbols_without_indexed_base = symbols - {ind.base.args[0] for ind in indexed}
symbols_without_indexed_base.update(indexed)
return {remove_brackets(str(s)) for s in symbols_without_indexed_base}
def isl_iteration_set(node: ast.Node):
"""Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node. """
conditions = []
degrees_of_freedom = set()
for loop in parents_of_type(node, ast.LoopOverCoordinate):
if loop.step != 1:
raise NotImplementedError("Loops with strides != 1 are not yet supported.")
degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.loop_counter_symbol))
degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.start))
degrees_of_freedom.update(_degrees_of_freedom_as_string(loop.stop))
loop_start_str = remove_brackets(str(loop.start))
loop_stop_str = remove_brackets(str(loop.stop))
ctr_name = loop.loop_counter_name
set_string_description = f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
conditions.append(remove_brackets(set_string_description))
symbol_names = ','.join(degrees_of_freedom)
condition_str = ' and '.join(conditions)
set_description = f"{{ [{symbol_names}] : {condition_str} }}"
return degrees_of_freedom, isl.BasicSet(set_description)
def simplify_loop_counter_dependent_conditional(conditional):
"""Removes conditionals that depend on the loop counter or iteration limits if they are always true/false."""
dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr)
dofs_in_loops, iteration_set = isl_iteration_set(conditional)
if dofs_in_condition.issubset(dofs_in_loops):
symbol_names = ','.join(dofs_in_loops)
condition_str = CustomSympyPrinter().doprint(conditional.condition_expr)
condition_str = remove_brackets(condition_str)
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
if condition_set.is_empty():
conditional.replace_by_false_block()
return
intersection = iteration_set.intersect(condition_set)
if intersection.is_empty():
conditional.replace_by_false_block()
elif intersection == iteration_set:
conditional.replace_by_true_block()
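# Hedged sketch (not part of the original module): the core islpy check performed by
# simplify_loop_counter_dependent_conditional, written out directly. The set strings
# below are made up for illustration only.
if __name__ == '__main__':
    iteration_space = isl.BasicSet("{ [ctr] : ctr >= 0 and ctr < 16 }")
    condition_set = isl.BasicSet("{ [ctr] : ctr >= 8 }")
    common = iteration_space.intersect(condition_set)
    if common.is_empty():
        print("condition is always false inside the loop")
    elif common == iteration_space:
        print("condition is always true inside the loop")
    else:
        # expected here: the condition only holds for 8 <= ctr < 16, a proper subset
        print("condition depends on the loop counter")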
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64
from tempfile import NamedTemporaryFile
from IPython import get_ipython
def log_progress(sequence, every=None, size=None, name='Items'):
"""Copied from https://github.com/alexanderkuk/log-progress"""
from ipywidgets import IntProgress, HTML, VBox
from IPython.display import display
is_iterator = False
if size is None:
try:
size = len(sequence)
except TypeError:
is_iterator = True
if size is not None:
if every is None:
if size <= 200:
every = 1
else:
every = int(size / 200) # every 0.5%
else:
assert every is not None, 'sequence is iterator, set every'
if is_iterator:
progress = IntProgress(min=0, max=1, value=1)
progress.bar_style = 'info'
else:
progress = IntProgress(min=0, max=size, value=0)
label = HTML()
box = VBox(children=[label, progress])
display(box)
index = 0
try:
for index, record in enumerate(sequence, 1):
if index == 1 or index % every == 0:
if is_iterator:
label.value = '{name}: {index} / ?'.format(
name=name,
index=index
)
else:
progress.value = index
label.value = u'{name}: {index} / {size}'.format(
name=name,
index=index,
size=size
)
yield record
except:
progress.bar_style = 'danger'
raise
else:
progress.bar_style = 'success'
progress.value = index
label.value = "{name}: {index}".format(
name=name,
index=str(index or '?')
)
import matplotlib.animation as animation
import sympy as sp
from IPython.display import HTML
import pystencils.plot as plt
VIDEO_TAG = """<video controls width="80%">
<source src="data:video/x-m4v;base64,{0}" type="video/mp4">
</video>"""
def __animation_to_html(animation, fps):
if not hasattr(animation, 'encoded_video'):
with NamedTemporaryFile(suffix='.mp4') as f:
animation.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264', '-pix_fmt',
'yuv420p', '-profile:v', 'baseline', '-level', '3.0'])
video = open(f.name, "rb").read()
animation.encoded_video = base64.b64encode(video).decode('ascii')
return VIDEO_TAG.format(animation.encoded_video)
def make_imshow_animation(grid, grid_update_function, frames=90, **_):
from functools import partial
fig = plt.figure()
im = plt.imshow(grid, interpolation='none')
def update_figure(*_, **kwargs):
image = kwargs['image']
image = grid_update_function(image)
im.set_array(image)
return im,
return animation.FuncAnimation(fig, partial(update_figure, image=grid), frames=frames)
# ------- Version 1: Embed the animation as HTML5 video --------- ----------------------------------
def display_as_html_video(animation, fps=30, show=True, **_):
try:
plt.close(animation._fig)
res = __animation_to_html(animation, fps)
if show:
return HTML(res)
else:
# ------- Version 2: Animation is shown in extra matplotlib window ----------------------------------
def display_in_extra_window(*_, **__):
fig = plt.gcf()
try:
fig.canvas.manager.window.raise_()
except Exception:
pass
plt.show()
# ------- Version 3: Animation is shown in images that are updated directly in website --------------
def display_as_html_image(animation, show=True, *args, **kwargs):
from IPython import display
try:
if show:
animation._init_draw()
for _ in animation.frame_seq:
if show:
fig = plt.gcf()
display.display(fig)
animation._step()
if show:
# Dispatcher
animation_display_mode = 'image_update'
display_animation_func = None
def display_animation(*args, **kwargs):
from IPython import get_ipython
ipython = get_ipython()
if not ipython:
return
if not display_animation_func:
raise "Call set_display_mode first"
raise Exception("Call set_display_mode first")
return display_animation_func(*args, **kwargs)
def set_display_mode(mode):
from IPython import get_ipython
ipython = get_ipython()
if not ipython:
return
global animation_display_mode
global display_animation_func
animation_display_mode = mode
if animation_display_mode == 'video':
ipython.magic("matplotlib inline")
display_animation_func = display_as_html_video
elif animation_display_mode == 'window':
ipython.magic("matplotlib qt")
display_animation_func = display_in_extra_window
elif animation_display_mode == 'image_update':
ipython.magic("matplotlib inline")
display_animation_func = display_as_html_image
else:
raise Exception("Unknown mode. Available modes 'imageupdate', 'video' and 'window' ")
raise Exception("Unknown mode. Available modes 'image_update', 'video' and 'window' ")
def activate_ipython():
from IPython import get_ipython
ipython = get_ipython()
if ipython:
set_display_mode('image_update')
ipython.magic("config InlineBackend.rc = { }")
ipython.magic("matplotlib inline")
plt.rc('figure', figsize=(16, 6))
sp.init_printing()
activate_ipython()
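# Hedged usage sketch (not part of the original module): a typical notebook flow with the
# helpers above. The toy grid and update function are made up for illustration; outside
# of IPython, set_display_mode and display_animation simply return without doing anything.
if __name__ == '__main__':
    import numpy as np

    def shift_rows(field):
        return np.roll(field, 1, axis=0)  # toy update rule: shift the field by one row

    set_display_mode('image_update')      # alternatives: 'video', 'window'
    anim = make_imshow_animation(np.random.rand(64, 64), shift_rows, frames=30)
    display_animation(anim, fps=10)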
from collections import namedtuple, defaultdict
from typing import Union
import sympy as sp
from sympy.codegen import Assignment
from pystencils.simp import AssignmentCollection
from pystencils import astnodes as ast, TypedSymbol
from pystencils.field import Field
from pystencils.node_collection import NodeCollection
from pystencils.transformations import NestedScopes
# TODO use this in Constraint Checker
accepted_functions = [
sp.Pow,
sp.sqrt,
sp.log,
# TODO trigonometric functions (and whatever tests will fail)
]
class KernelConstraintsCheck:
# TODO: proper specification
# TODO: More checks :)
"""Checks if the input to create_kernel is valid.
Tests the following conditions:
- SSA form for pure symbols:
- Every pure symbol may occur only once as the left-hand side of an assignment
- Every pure symbol that is read may not be written to later
- Independence / parallelization condition:
- A field that is written may only be read at exactly the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, check_independence_condition=True, check_double_write_condition=True):
self.scopes = NestedScopes()
self.field_reads = defaultdict(set)
self.field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
self.check_double_write_condition = check_double_write_condition
def visit(self, obj):
if isinstance(obj, (AssignmentCollection, NodeCollection)):
[self.visit(e) for e in obj.all_assignments]
elif isinstance(obj, list) or isinstance(obj, tuple):
[self.visit(e) for e in obj]
elif isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
self.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
self.scopes.push()
# Disable double write check inside conditionals
# would be triggered by e.g. in-kernel boundaries
old_double_write = self.check_double_write_condition
old_independence_condition = self.check_independence_condition
self.check_double_write_condition = False
self.check_independence_condition = False
if obj.false_block:
self.visit(obj.false_block)
self.process_expression(obj.condition_expr)
self.process_expression(obj.true_block)
self.check_double_write_condition = old_double_write
self.check_independence_condition = old_independence_condition
self.scopes.pop()
elif isinstance(obj, ast.Block):
self.scopes.push()
[self.visit(e) for e in obj.args]
self.scopes.pop()
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
pass
else:
raise ValueError(f'Invalid object in kernel {type(obj)}')
def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
self.process_expression(assignment.rhs)
self.process_lhs(assignment.lhs)
def process_expression(self, rhs):
# TODO constraint for accepted functions, see TODO above
self.update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
else:
for arg in rhs.args:
self.process_expression(arg)
@property
def fields_written(self):
"""
Return all rhs fields
"""
return set(k.field for k, v in self.field_writes.items() if len(v))
def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]):
assert isinstance(lhs, sp.Symbol)
self.update_accesses_lhs(lhs)
def update_accesses_lhs(self, lhs):
if isinstance(lhs, Field.Access):
fai = self.FieldAndIndex(lhs.field, lhs.index)
if self.check_double_write_condition and lhs.offsets in self.field_writes[fai]:
raise ValueError(f"Field {lhs.field.name} is written twice at the same location")
self.field_writes[fai].add(lhs.offsets)
if self.check_double_write_condition and len(self.field_writes[fai]) > 1:
raise ValueError(
f"Field {lhs.field.name} is written at two different locations")
if fai in self.field_reads:
reads = tuple(self.field_reads[fai])
if len(reads) > 1 or lhs.offsets != reads[0]:
if self.check_independence_condition:
raise ValueError(f"Field {lhs.field.name} is written at different location than it was read. "
f"This means the resulting kernel would not be thread safe")
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
if lhs in self.scopes.free_parameters:
raise ValueError(f"Symbol {lhs.name} is written, after it has been read")
self.scopes.define_symbol(lhs)
def update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition:
fai = self.FieldAndIndex(rhs.field, rhs.index)
writes = self.field_writes[fai]
self.field_reads[fai].add(rhs.offsets)
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError(f"Violation of loop independence condition. Field "
f"{rhs.field} is read at {rhs.offsets} and written at {write_offset}")
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
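# Hedged usage sketch (not part of the original module), assuming pystencils' public
# Assignment and fields helpers: the checker is driven through visit(), and a second
# assignment to the same pure symbol violates the SSA condition from the class docstring.
if __name__ == '__main__':
    import pystencils as ps

    f = ps.fields("f: [2D]")
    x = sp.Symbol("x")
    checker = KernelConstraintsCheck()
    try:
        checker.visit([ps.Assignment(x, f[0, 0] + 1),
                       ps.Assignment(x, 2 * x)])   # x is written twice -> not in SSA form
    except ValueError as err:
        print(err)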
import ast
import inspect
import textwrap
from typing import Callable, Union, List, Dict, Tuple
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
from pystencils.config import CreateKernelConfig
__all__ = ['kernel', 'kernel_config']
def _kernel(func: Callable[..., None], **kwargs) -> Tuple[List[Assignment], str]:
"""
Convenience function shared by the kernel decorators to prevent code duplication
Args:
func: decorated function
**kwargs: kwargs for the function
Returns:
assignments, function_name
"""
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
return assignments, func.__name__
def kernel(func: Callable[..., None], **kwargs) -> List[Assignment]:
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore, the ternary inline 'if-else' changes its meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with the exception of an argument called 's' that specifies
a SymbolCreator()
Args:
func: decorated function
**kwargs: kwargs for the function
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
"""
assignments, _ = _kernel(func, **kwargs)
return assignments
def kernel_config(config: CreateKernelConfig, **kwargs) -> Callable[..., Dict]:
"""Decorator to simplify generation of pystencils Assignments, which takes a configuration
and updates the function name accordingly.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore, the ternary inline 'if-else' changes its meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with the exception of an argument called 's' that specifies
a SymbolCreator()
Args:
config: kernel configuration object; its function_name is set to the name of the decorated function,
and it is returned together with the assignments for unpacking into create_kernel
Returns:
decorator with config
Examples:
>>> import pystencils as ps
>>> kernel_configuration = ps.CreateKernelConfig()
>>> @kernel_config(kernel_configuration)
... def my_kernel(s):
... src, dst = ps.fields('src, dst: [2D]')
... s.neighbors @= src[0, 1] + src[1, 0]
... dst[0, 0] @= s.neighbors + src[0, 0] if src[0, 0] > 0 else 0
>>> f, g = ps.fields('src, dst: [2D]')
>>> assert my_kernel['assignments'][0].rhs == f[0, 1] + f[1, 0]
"""
def decorator(func: Callable[..., None]) -> Union[List[Assignment], Dict]:
"""
Args:
func: decorated function
Returns:
Dict for unpacking into create_kernel
"""
assignments, func_name = _kernel(func, **kwargs)
config.function_name = func_name
return {'assignments': assignments, 'config': config}
return decorator
# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
def visit_IfExp(self, node):
piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load())
piecewise_func = ast.copy_location(piecewise_func, node)
piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
ast.Tuple(elts=[node.orelse, ast.NameConstant(value=True)], ctx=ast.Load())]
result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])
return ast.copy_location(result, node)
def visit_AugAssign(self, node):
self.generic_visit(node)
node.target.ctx = ast.Load()
new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
args=[node.target, node.value],
keywords=[]))
return ast.copy_location(new_node, node)
def visit_FunctionDef(self, node):
self.generic_visit(node)
node.decorator_list = []
return node
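# Hedged sketch (not part of the original module): what the NodeTransformer above does
# to a single '@=' line. The augmented assignment becomes a call to _add_assignment and
# the inline if/else becomes a _Piecewise call; ast.unparse requires Python >= 3.9.
if __name__ == '__main__':
    tree = ast.parse("g @= s + f if f > 0 else 0")
    KernelFunctionRewrite().visit(tree)
    ast.fix_missing_locations(tree)
    # prints: _add_assignment(g, _Piecewise((s + f, f > 0), (0, True)))
    print(ast.unparse(tree))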