Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 2585 additions and 137 deletions
/*
Copyright 2010-2011, D. E. Shaw Research. All rights reserved.
Copyright 2019-2024, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#if !defined(__OPENCL_VERSION__) && !defined(__HIPCC_RTC__)
#if defined(__SSE2__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#include <emmintrin.h> // SSE2
#endif
#ifdef __AVX2__
#include <immintrin.h> // AVX*
#elif defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif
#if defined(_MSC_VER) && defined(_M_ARM64)
#define __ARM_NEON
#endif
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
#include <arm_sve.h>
#endif
#if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) && !defined(__xlC__)
#include <ppu_intrinsics.h>
#endif
#ifdef __ALTIVEC__
#include <altivec.h>
#undef bool
#ifndef _ARCH_PWR8
#include <pveclib/vec_int64_ppc.h>
#endif
#endif
#ifdef __riscv_v
#include <riscv_vector.h>
#endif
#endif
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
#define QUALIFIERS static __forceinline__ __device__
#elif defined(__OPENCL_VERSION__)
#define QUALIFIERS static inline
#else
#define QUALIFIERS inline
#include "myintrin.h"
#endif
#if defined(__ARM_FEATURE_SME)
#define SVE_QUALIFIERS __attribute__((arm_streaming_compatible)) QUALIFIERS
#else
#define SVE_QUALIFIERS QUALIFIERS
#endif
#define PHILOX_W32_0 (0x9E3779B9)
#define PHILOX_W32_1 (0xBB67AE85)
#define PHILOX_M4x32_0 (0xD2511F53)
#define PHILOX_M4x32_1 (0xCD9E8D57)
#define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16)
#define TWOPOW32_INV_FLOAT (2.3283064e-10f)
#ifdef __OPENCL_VERSION__
#include "opencl_stdint.h"
typedef uint32_t uint32;
typedef uint64_t uint64;
#else
#ifndef __HIPCC_RTC__
#include <cstdint>
#endif
typedef std::uint32_t uint32;
typedef std::uint64_t uint64;
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#elif defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
typedef svfloat32_t svfloat32_st;
typedef svfloat64_t svfloat64_st;
#endif
QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip)
{
#if !defined(__CUDA_ARCH__) && !defined(__HIP_DEVICE_COMPILE__)
// host code
#if defined(__powerpc__) && (!defined(__clang__) || defined(__xlC__))
*hip = __mulhwu(a,b);
return a*b;
#elif defined(__OPENCL_VERSION__)
*hip = mul_hi(a,b);
return a*b;
#else
uint64 product = ((uint64)a) * ((uint64)b);
*hip = product >> 32;
return (uint32)product;
#endif
#else
// device code
*hip = __umulhi(a,b);
return a*b;
#endif
}
QUALIFIERS void _philox4x32round(uint32* ctr, uint32* key)
{
uint32 hi0;
uint32 hi1;
uint32 lo0 = mulhilo32(PHILOX_M4x32_0, ctr[0], &hi0);
uint32 lo1 = mulhilo32(PHILOX_M4x32_1, ctr[2], &hi1);
ctr[0] = hi1^ctr[1]^key[0];
ctr[1] = lo1;
ctr[2] = hi0^ctr[3]^key[1];
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(uint32* key)
{
key[0] += PHILOX_W32_0;
key[1] += PHILOX_W32_1;
}
QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y)
{
uint64 z = (uint64)x ^ ((uint64)y << (53 - 32));
return z * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
double * rnd1, double * rnd2)
#else
double & rnd1, double & rnd2)
#endif
{
uint32 key[2] = {key0, key1};
uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
#ifdef __OPENCL_VERSION__
*rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
*rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#else
rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#endif
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
float * rnd1, float * rnd2, float * rnd3, float * rnd4)
#else
float & rnd1, float & rnd2, float & rnd3, float & rnd4)
#endif
{
uint32 key[2] = {key0, key1};
uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
#ifdef __OPENCL_VERSION__
*rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
*rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
#else
rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f);
#endif
}
#if !defined(__CUDA_ARCH__) && !defined(__OPENCL_VERSION__) && !defined(__HIP_DEVICE_COMPILE__)
#if defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key)
{
__m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi0b = _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0));
__m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1));
__m128i lohi1b = _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8);
__m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b);
__m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b);
__m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b);
__m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m128i* key)
{
key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0));
key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
}
else
{
x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
}
// calculate z = x ^ y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void philox_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my_cvtepu32_ps(ctr[0]);
rnd2 = _my_cvtepu32_ps(ctr[1]);
rnd3 = _my_cvtepu32_ps(ctr[2]);
rnd4 = _my_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
__m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m128d & rnd1, __m128d & rnd2)
{
__m128i ctr0v = _mm_set1_epi32(ctr0);
__m128i ctr2v = _mm_set1_epi32(ctr2);
__m128i ctr3v = _mm_set1_epi32(ctr3);
__m128d ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
#endif
#ifdef __ALTIVEC__
QUALIFIERS void _philox4x32round(__vector unsigned int* ctr, __vector unsigned int* key)
{
#ifndef _ARCH_PWR8
__vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int hi0 = vec_mulhuw(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi1 = vec_mulhuw(ctr[2], vec_splats(PHILOX_M4x32_1));
#elif defined(_ARCH_PWR10)
__vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int hi0 = vec_mulh(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int hi1 = vec_mulh(ctr[2], vec_splats(PHILOX_M4x32_1));
#else
__vector unsigned int lohi0a = (__vector unsigned int) vec_mule(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lohi0b = (__vector unsigned int) vec_mulo(ctr[0], vec_splats(PHILOX_M4x32_0));
__vector unsigned int lohi1a = (__vector unsigned int) vec_mule(ctr[2], vec_splats(PHILOX_M4x32_1));
__vector unsigned int lohi1b = (__vector unsigned int) vec_mulo(ctr[2], vec_splats(PHILOX_M4x32_1));
#ifdef __LITTLE_ENDIAN__
__vector unsigned int lo0 = vec_mergee(lohi0a, lohi0b);
__vector unsigned int lo1 = vec_mergee(lohi1a, lohi1b);
__vector unsigned int hi0 = vec_mergeo(lohi0a, lohi0b);
__vector unsigned int hi1 = vec_mergeo(lohi1a, lohi1b);
#else
__vector unsigned int lo0 = vec_mergeo(lohi0a, lohi0b);
__vector unsigned int lo1 = vec_mergeo(lohi1a, lohi1b);
__vector unsigned int hi0 = vec_mergee(lohi0a, lohi0b);
__vector unsigned int hi1 = vec_mergee(lohi1a, lohi1b);
#endif
#endif
ctr[0] = vec_xor(vec_xor(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = vec_xor(vec_xor(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__vector unsigned int* key)
{
key[0] = vec_add(key[0], vec_splats(PHILOX_W32_0));
key[1] = vec_add(key[1], vec_splats(PHILOX_W32_1));
}
#ifdef __VSX__
template<bool high>
QUALIFIERS __vector double _uniform_double_hq(__vector unsigned int x, __vector unsigned int y)
{
// convert 32 to 64 bit
#ifdef __LITTLE_ENDIAN__
if (high)
{
x = vec_mergel(x, vec_splats(0U));
y = vec_mergel(y, vec_splats(0U));
}
else
{
x = vec_mergeh(x, vec_splats(0U));
y = vec_mergeh(y, vec_splats(0U));
}
#else
if (high)
{
x = vec_mergel(vec_splats(0U), x);
y = vec_mergel(vec_splats(0U), y);
}
else
{
x = vec_mergeh(vec_splats(0U), x);
y = vec_mergeh(vec_splats(0U), y);
}
#endif
// calculate z = x ^ y << (53 - 32))
#ifdef _ARCH_PWR8
__vector unsigned long long z = vec_sl((__vector unsigned long long) y, vec_splats(53ULL - 32ULL));
#else
__vector unsigned long long z = vec_vsld((__vector unsigned long long) y, vec_splats(53ULL - 32ULL));
#endif
z = vec_xor((__vector unsigned long long) x, z);
// convert uint64 to double
#ifdef __xlC__
__vector double rs = vec_ctd(z, 0);
#else
__vector double rs = vec_ctf(z, 0);
#endif
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vec_madd(rs, vec_splats(TWOPOW53_INV_DOUBLE), vec_splats(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
#endif
QUALIFIERS void philox_float4(__vector unsigned int ctr0, __vector unsigned int ctr1, __vector unsigned int ctr2, __vector unsigned int ctr3,
uint32 key0, uint32 key1,
__vector float & rnd1, __vector float & rnd2, __vector float & rnd3, __vector float & rnd4)
{
__vector unsigned int key[2] = {vec_splats(key0), vec_splats(key1)};
__vector unsigned int ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = vec_ctf(ctr[0], 0);
rnd2 = vec_ctf(ctr[1], 0);
rnd3 = vec_ctf(ctr[2], 0);
rnd4 = vec_ctf(ctr[3], 0);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vec_madd(rnd1, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = vec_madd(rnd2, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = vec_madd(rnd3, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = vec_madd(rnd4, vec_splats(TWOPOW32_INV_FLOAT), vec_splats(TWOPOW32_INV_FLOAT/2.0f));
}
#ifdef __VSX__
QUALIFIERS void philox_double2(__vector unsigned int ctr0, __vector unsigned int ctr1, __vector unsigned int ctr2, __vector unsigned int ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1lo, __vector double & rnd1hi, __vector double & rnd2lo, __vector double & rnd2hi)
{
__vector unsigned int key[2] = {vec_splats(key0), vec_splats(key1)};
__vector unsigned int ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
#endif
QUALIFIERS void philox_float4(uint32 ctr0, __vector unsigned int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector float & rnd1, __vector float & rnd2, __vector float & rnd3, __vector float & rnd4)
{
__vector unsigned int ctr0v = vec_splats(ctr0);
__vector unsigned int ctr2v = vec_splats(ctr2);
__vector unsigned int ctr3v = vec_splats(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_float4(uint32 ctr0, __vector int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector float & rnd1, __vector float & rnd2, __vector float & rnd3, __vector float & rnd4)
{
philox_float4(ctr0, (__vector unsigned int) ctr1, ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
#ifdef __VSX__
QUALIFIERS void philox_double2(uint32 ctr0, __vector unsigned int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1lo, __vector double & rnd1hi, __vector double & rnd2lo, __vector double & rnd2hi)
{
__vector unsigned int ctr0v = vec_splats(ctr0);
__vector unsigned int ctr2v = vec_splats(ctr2);
__vector unsigned int ctr3v = vec_splats(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __vector unsigned int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1, __vector double & rnd2)
{
__vector unsigned int ctr0v = vec_splats(ctr0);
__vector unsigned int ctr2v = vec_splats(ctr2);
__vector unsigned int ctr3v = vec_splats(ctr3);
__vector double ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, __vector int ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__vector double & rnd1, __vector double & rnd2)
{
philox_double2(ctr0, (__vector unsigned int) ctr1, ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#endif
#if defined(__ARM_NEON)
QUALIFIERS void _philox4x32round(uint32x4_t* ctr, uint32x4_t* key)
{
uint32x4_t lohi0a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[0]), vdup_n_u32(PHILOX_M4x32_0)));
uint32x4_t lohi0b = vreinterpretq_u32_u64(vmull_high_u32(ctr[0], vdupq_n_u32(PHILOX_M4x32_0)));
uint32x4_t lohi1a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[2]), vdup_n_u32(PHILOX_M4x32_1)));
uint32x4_t lohi1b = vreinterpretq_u32_u64(vmull_high_u32(ctr[2], vdupq_n_u32(PHILOX_M4x32_1)));
uint32x4_t lo0 = vuzp1q_u32(lohi0a, lohi0b);
uint32x4_t lo1 = vuzp1q_u32(lohi1a, lohi1b);
uint32x4_t hi0 = vuzp2q_u32(lohi0a, lohi0b);
uint32x4_t hi1 = vuzp2q_u32(lohi1a, lohi1b);
ctr[0] = veorq_u32(veorq_u32(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = veorq_u32(veorq_u32(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(uint32x4_t* key)
{
key[0] = vaddq_u32(key[0], vdupq_n_u32(PHILOX_W32_0));
key[1] = vaddq_u32(key[1], vdupq_n_u32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS float64x2_t _uniform_double_hq(uint32x4_t x, uint32x4_t y)
{
// convert 32 to 64 bit
if (high)
{
x = vzip2q_u32(x, vdupq_n_u32(0));
y = vzip2q_u32(y, vdupq_n_u32(0));
}
else
{
x = vzip1q_u32(x, vdupq_n_u32(0));
y = vzip1q_u32(y, vdupq_n_u32(0));
}
// calculate z = x ^ y << (53 - 32))
uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
z = veorq_u64(vreinterpretq_u64_u32(x), z);
// convert uint64 to double
float64x2_t rs = vcvtq_f64_u64(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE/2.0), vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);
return rs;
}
QUALIFIERS void philox_float4(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = vcvtq_f32_u32(ctr[0]);
rnd2 = vcvtq_f32_u32(ctr[1]);
rnd3 = vcvtq_f32_u32(ctr[2]);
rnd4 = vcvtq_f32_u32(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd1);
rnd2 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd2);
rnd3 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd3);
rnd4 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd4);
}
QUALIFIERS void philox_double2(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
#ifndef _MSC_VER
QUALIFIERS void philox_float4(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
philox_float4(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
#endif
QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1, float64x2_t & rnd2)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
float64x2_t ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
#ifndef _MSC_VER
QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
float64x2_t & rnd1, float64x2_t & rnd2)
{
philox_double2(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#endif
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
SVE_QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
{
svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0)));
ctr = svset4_u32(ctr, 1, lo1);
ctr = svset4_u32(ctr, 2, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, svget4_u32(ctr, 3)), svget2_u32(key, 1)));
ctr = svset4_u32(ctr, 3, lo0);
}
SVE_QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key)
{
key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0)));
key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1)));
}
template<bool high>
SVE_QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y)
{
// convert 32 to 64 bit
if (high)
{
x = svzip2_u32(x, svdup_u32(0));
y = svzip2_u32(y, svdup_u32(0));
}
else
{
x = svzip1_u32(x, svdup_u32(0));
y = svzip1_u32(y, svdup_u32(0));
}
// calculate z = x ^ y << (53 - 32))
svuint64_t z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z);
// convert uint64 to double
svfloat64_t rs = svcvt_f64_u64_x(svptrue_b64(), z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
SVE_QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 0));
rnd2 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 1));
rnd3 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 2));
rnd4 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 3));
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd3 = svmad_f32_x(svptrue_b32(), rnd3, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd4 = svmad_f32_x(svptrue_b32(), rnd4, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
}
SVE_QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd1hi = _uniform_double_hq<true>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd2lo = _uniform_double_hq<false>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
}
SVE_QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
SVE_QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
SVE_QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
svfloat64_st ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
SVE_QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#if defined(__riscv_v)
QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32m1_t & ctr2, vuint32m1_t & ctr3,
vuint32m1_t key0, vuint32m1_t key1)
{
vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1());
ctr1 = lo1;
ctr2 = vxor_vv_u32m1(vxor_vv_u32m1(hi0, ctr3, vsetvlmax_e32m1()), key1, vsetvlmax_e32m1());
ctr3 = lo0;
}
QUALIFIERS void _philox4x32bumpkey(vuint32m1_t & key0, vuint32m1_t & key1)
{
key0 = vadd_vv_u32m1(key0, vmv_v_x_u32m1(PHILOX_W32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
key1 = vadd_vv_u32m1(key1, vmv_v_x_u32m1(PHILOX_W32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1());
}
template<bool high>
QUALIFIERS vfloat64m1_t _uniform_double_hq(vuint32m1_t x, vuint32m1_t y)
{
// convert 32 to 64 bit
if (high)
{
size_t s = vsetvlmax_e32m1();
x = vslidedown_vx_u32m1(vundefined_u32m1(), x, s/2, s);
y = vslidedown_vx_u32m1(vundefined_u32m1(), y, s/2, s);
}
vuint64m1_t x64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(x), vsetvlmax_e64m1());
vuint64m1_t y64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(y), vsetvlmax_e64m1());
// calculate z = x ^ y << (53 - 32))
vuint64m1_t z = vsll_vx_u64m1(y64, 53 - 32, vsetvlmax_e64m1());
z = vxor_vv_u64m1(x64, z, vsetvlmax_e64m1());
// convert uint64 to double
vfloat64m1_t rs = vfcvt_f_xu_v_f64m1(z, vsetvlmax_e64m1());
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmadd_vv_f64m1(rs, vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE, vsetvlmax_e64m1()), vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE/2.0, vsetvlmax_e64m1()), vsetvlmax_e64m1());
return rs;
}
QUALIFIERS void philox_float4(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1());
vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1());
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10
// convert uint32 to float
rnd1 = vfcvt_f_xu_v_f32m1(ctr0, vsetvlmax_e32m1());
rnd2 = vfcvt_f_xu_v_f32m1(ctr1, vsetvlmax_e32m1());
rnd3 = vfcvt_f_xu_v_f32m1(ctr2, vsetvlmax_e32m1());
rnd4 = vfcvt_f_xu_v_f32m1(ctr3, vsetvlmax_e32m1());
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmadd_vv_f32m1(rnd1, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd2 = vfmadd_vv_f32m1(rnd2, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd3 = vfmadd_vv_f32m1(rnd3, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
rnd4 = vfmadd_vv_f32m1(rnd4, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1());
}
QUALIFIERS void philox_double2(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi)
{
vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1());
vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1());
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
_philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10
rnd1lo = _uniform_double_hq<false>(ctr0, ctr1);
rnd1hi = _uniform_double_hq<true>(ctr0, ctr1);
rnd2lo = _uniform_double_hq<false>(ctr2, ctr3);
rnd2hi = _uniform_double_hq<true>(ctr2, ctr3);
}
QUALIFIERS void philox_float4(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_float4(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4)
{
philox_float4(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1, vfloat64m1_t & rnd2)
{
vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1());
vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1());
vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1());
vfloat64m1_t ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
vfloat64m1_t & rnd1, vfloat64m1_t & rnd2)
{
philox_double2(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
}
#endif
#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
{
__m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32), _mm256_set1_epi32(PHILOX_M4x32_0));
__m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1));
__m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32), _mm256_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8);
lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8);
lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8);
lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8);
__m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b);
__m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b);
__m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b);
__m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m256i* key)
{
key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0));
key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 1));
y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 1));
}
else
{
x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 0));
y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 0));
}
// calculate z = x ^ y << (53 - 32))
__m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm256_xor_si256(x, z);
// convert uint64 to double
__m256d rs = _my256_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
return rs;
}
QUALIFIERS void philox_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _my256_cvtepu32_ps(ctr[0]);
rnd2 = _my256_cvtepu32_ps(ctr[1]);
rnd3 = _my256_cvtepu32_ps(ctr[2]);
rnd4 = _my256_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0));
#else
rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
}
QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
__m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
{
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi)
{
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m256d & rnd1, __m256d & rnd2)
{
#if 0
__m256i ctr0v = _mm256_set1_epi32(ctr0);
__m256i ctr2v = _mm256_set1_epi32(ctr2);
__m256i ctr3v = _mm256_set1_epi32(ctr3);
__m256d ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
#else
__m128d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
philox_double2(ctr0, _mm256_extractf128_si256(ctr1, 0), ctr2, ctr3, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
rnd1 = _my256_set_m128d(rnd1hi, rnd1lo);
rnd2 = _my256_set_m128d(rnd2hi, rnd2lo);
#endif
}
#endif
#if defined(__AVX512F__) || defined(__AVX10_512BIT__)
QUALIFIERS void _philox4x32round(__m512i* ctr, __m512i* key)
{
__m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32), _mm512_set1_epi32(PHILOX_M4x32_0));
__m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1));
__m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32), _mm512_set1_epi32(PHILOX_M4x32_1));
lohi0a = _mm512_shuffle_epi32(lohi0a, _MM_PERM_DBCA);
lohi0b = _mm512_shuffle_epi32(lohi0b, _MM_PERM_DBCA);
lohi1a = _mm512_shuffle_epi32(lohi1a, _MM_PERM_DBCA);
lohi1b = _mm512_shuffle_epi32(lohi1b, _MM_PERM_DBCA);
__m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b);
__m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b);
__m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b);
__m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b);
ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]);
ctr[3] = lo0;
}
QUALIFIERS void _philox4x32bumpkey(__m512i* key)
{
key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0));
key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1));
}
template<bool high>
QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
{
// convert 32 to 64 bit
if (high)
{
x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 1));
y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 1));
}
else
{
x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 0));
y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 0));
}
// calculate z = x ^ y << (53 - 32))
__m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm512_xor_si512(x, z);
// convert uint64 to double
__m512d rs = _mm512_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
QUALIFIERS void philox_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = _mm512_cvtepu32_ps(ctr[0]);
rnd2 = _mm512_cvtepu32_ps(ctr[1]);
rnd3 = _mm512_cvtepu32_ps(ctr[2]);
rnd4 = _mm512_cvtepu32_ps(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0));
}
QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
__m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
QUALIFIERS void philox_float4(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
{
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi)
{
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
__m512d & rnd1, __m512d & rnd2)
{
#if 0
__m512i ctr0v = _mm512_set1_epi32(ctr0);
__m512i ctr2v = _mm512_set1_epi32(ctr2);
__m512i ctr3v = _mm512_set1_epi32(ctr3);
__m512d ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
#else
__m256d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
philox_double2(ctr0, _mm512_extracti64x4_epi64(ctr1, 0), ctr2, ctr3, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
rnd1 = _my512_set_m256d(rnd1hi, rnd1lo);
rnd2 = _my512_set_m256d(rnd2hi, rnd2lo);
#endif
}
#endif
#endif
/*
Copyright 2021, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <altivec.h>
#undef vector
#undef bool
inline void cachelineZero(void * p) {
#ifdef __xlC__
__dcbz(p);
#else
__asm__ volatile("dcbz 0, %0"::"r"(p):"memory");
#endif
}
inline size_t _cachelineSize() {
// allocate and fill with ones
const size_t max_size = 0x100000;
uint8_t data[2*max_size];
for (size_t i = 0; i < 2*max_size; ++i) {
data[i] = 0xff;
}
// find alignment offset
size_t offset = max_size - ((uintptr_t) data) % max_size;
// zero a cacheline
cachelineZero((void*) (data + offset));
// make sure that at least one byte was zeroed
if (data[offset] != 0) {
return SIZE_MAX;
}
// make sure that nothing was zeroed before the pointer
if (data[offset-1] == 0) {
return SIZE_MAX;
}
// find the last byte that was zeroed
for (size_t size = 1; size < max_size; ++size) {
if (data[offset + size] != 0) {
return size;
}
}
// too much was zeroed
return SIZE_MAX;
}
inline size_t cachelineSize() {
static size_t size = _cachelineSize();
return size;
}
/*
Copyright 2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
inline void cachelineZero(void * p) {
#ifdef __riscv_zicboz
__asm__ volatile("cbo.zero (%0)"::"r"(p):"memory");
#endif
}
inline size_t _cachelineSize() {
// allocate and fill with ones
const size_t max_size = 0x100000;
uint8_t data[2*max_size];
for (size_t i = 0; i < 2*max_size; ++i) {
data[i] = 0xff;
}
// find alignment offset
size_t offset = max_size - ((uintptr_t) data) % max_size;
// zero a cacheline
cachelineZero((void*) (data + offset));
// make sure that at least one byte was zeroed
if (data[offset] != 0) {
return SIZE_MAX;
}
// make sure that nothing was zeroed before the pointer
if (data[offset-1] == 0) {
return SIZE_MAX;
}
// find the last byte that was zeroed
for (size_t size = 1; size < max_size; ++size) {
if (data[offset + size] != 0) {
return size;
}
}
// too much was zeroed
return SIZE_MAX;
}
inline size_t cachelineSize() {
#ifdef __riscv_zicboz
static size_t size = _cachelineSize();
return size;
#else
return SIZE_MAX;
#endif
}
# TODO #47 move to a module functions
import numpy as np
import sympy as sp
from pystencils.data_types import cast_func, collate_types, create_type, get_type_of_expression
from pystencils.typing import CastFunc, collate_types, create_type, get_type_of_expression
from pystencils.sympyextensions import is_integer_sequence
......@@ -12,21 +13,29 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
args = []
for a in (arg1, arg2):
if isinstance(a, sp.Number) or isinstance(a, int):
args.append(cast_func(a, create_type("int")))
args.append(CastFunc(a, create_type("int")))
elif isinstance(a, np.generic):
args.append(cast_func(a, a.dtype))
args.append(CastFunc(a, a.dtype))
else:
args.append(a)
for a in args:
try:
type = get_type_of_expression(a)
if not type.is_int():
raise ValueError("Argument to integer function is not an int but " + str(type))
dtype = get_type_of_expression(a)
if not dtype.is_int():
raise ValueError("Argument to integer function is not an int but " + str(dtype))
except NotImplementedError:
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args)
def _eval_evalf(self, *pargs, **kwargs):
arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
return self._eval_op(arg1, arg2)
def _eval_op(self, arg1, arg2):
return self
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
......@@ -55,7 +64,9 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn):
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
pass
def _eval_op(self, arg1, arg2):
return int(arg1 // arg2)
# noinspection PyPep8Naming
......
......@@ -4,7 +4,8 @@ import islpy as isl
import sympy as sp
import pystencils.astnodes as ast
from pystencils.transformations import parents_of_type
from pystencils.typing import parents_of_type
from pystencils.backends.cbackend import CustomSympyPrinter
def remove_brackets(s):
......@@ -36,13 +37,12 @@ def isl_iteration_set(node: ast.Node):
loop_start_str = remove_brackets(str(loop.start))
loop_stop_str = remove_brackets(str(loop.stop))
ctr_name = loop.loop_counter_name
set_string_description = "{} >= {} and {} < {}".format(ctr_name, loop_start_str, ctr_name, loop_stop_str)
set_string_description = f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
conditions.append(remove_brackets(set_string_description))
symbol_names = ','.join(degrees_of_freedom)
condition_str = ' and '.join(conditions)
set_description = "{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
condition_str=condition_str)
set_description = f"{{ [{symbol_names}] : {condition_str} }}"
return degrees_of_freedom, isl.BasicSet(set_description)
......@@ -52,12 +52,13 @@ def simplify_loop_counter_dependent_conditional(conditional):
dofs_in_loops, iteration_set = isl_iteration_set(conditional)
if dofs_in_condition.issubset(dofs_in_loops):
symbol_names = ','.join(dofs_in_loops)
condition_str = remove_brackets(str(conditional.condition_expr))
condition_set = isl.BasicSet("{{ [{symbol_names}] : {condition_str} }}".format(symbol_names=symbol_names,
condition_str=condition_str))
condition_str = CustomSympyPrinter().doprint(conditional.condition_expr)
condition_str = remove_brackets(condition_str)
condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
if condition_set.is_empty():
conditional.replace_by_false_block()
return
intersection = iteration_set.intersect(condition_set)
if intersection.is_empty():
......
......@@ -7,67 +7,6 @@ from IPython.display import HTML
import pystencils.plot as plt
__all__ = ['log_progress', 'make_imshow_animation', 'display_animation', 'set_display_mode']
def log_progress(sequence, every=None, size=None, name='Items'):
"""Copied from https://github.com/alexanderkuk/log-progress"""
from ipywidgets import IntProgress, HTML, VBox
from IPython.display import display
is_iterator = False
if size is None:
try:
size = len(sequence)
except TypeError:
is_iterator = True
if size is not None:
if every is None:
if size <= 200:
every = 1
else:
every = int(size / 200) # every 0.5%
else:
assert every is not None, 'sequence is iterator, set every'
if is_iterator:
progress = IntProgress(min=0, max=1, value=1)
progress.bar_style = 'info'
else:
progress = IntProgress(min=0, max=size, value=0)
label = HTML()
box = VBox(children=[label, progress])
display(box)
index = 0
try:
for index, record in enumerate(sequence, 1):
if index == 1 or index % every == 0:
if is_iterator:
label.value = '{name}: {index} / ?'.format(
name=name,
index=index
)
else:
progress.value = index
label.value = u'{name}: {index} / {size}'.format(
name=name,
index=index,
size=size
)
yield record
except:
progress.bar_style = 'danger'
raise
else:
progress.bar_style = 'success'
progress.value = index
label.value = "{name}: {index}".format(
name=name,
index=str(index or '?')
)
VIDEO_TAG = """<video controls width="80%">
<source src="data:video/x-m4v;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
......
from collections import namedtuple, defaultdict
from typing import Union
import sympy as sp
from sympy.codegen import Assignment
from pystencils.simp import AssignmentCollection
from pystencils import astnodes as ast, TypedSymbol
from pystencils.field import Field
from pystencils.node_collection import NodeCollection
from pystencils.transformations import NestedScopes
# TODO use this in Constraint Checker
accepted_functions = [
sp.Pow,
sp.sqrt,
sp.log,
# TODO trigonometric functions (and whatever tests will fail)
]
class KernelConstraintsCheck:
# TODO: proper specification
# TODO: More checks :)
"""Checks if the input to create_kernel is valid.
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exact the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, check_independence_condition=True, check_double_write_condition=True):
self.scopes = NestedScopes()
self.field_reads = defaultdict(set)
self.field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
self.check_double_write_condition = check_double_write_condition
def visit(self, obj):
if isinstance(obj, (AssignmentCollection, NodeCollection)):
[self.visit(e) for e in obj.all_assignments]
elif isinstance(obj, list) or isinstance(obj, tuple):
[self.visit(e) for e in obj]
elif isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
self.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
self.scopes.push()
# Disable double write check inside conditionals
# would be triggered by e.g. in-kernel boundaries
old_double_write = self.check_double_write_condition
old_independence_condition = self.check_independence_condition
self.check_double_write_condition = False
self.check_independence_condition = False
if obj.false_block:
self.visit(obj.false_block)
self.process_expression(obj.condition_expr)
self.process_expression(obj.true_block)
self.check_double_write_condition = old_double_write
self.check_independence_condition = old_independence_condition
self.scopes.pop()
elif isinstance(obj, ast.Block):
self.scopes.push()
[self.visit(e) for e in obj.args]
self.scopes.pop()
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
pass
else:
raise ValueError(f'Invalid object in kernel {type(obj)}')
def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
self.process_expression(assignment.rhs)
self.process_lhs(assignment.lhs)
def process_expression(self, rhs):
# TODO constraint for accepted functions, see TODO above
self.update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
else:
for arg in rhs.args:
self.process_expression(arg)
@property
def fields_written(self):
"""
Return all rhs fields
"""
return set(k.field for k, v in self.field_writes.items() if len(v))
def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]):
assert isinstance(lhs, sp.Symbol)
self.update_accesses_lhs(lhs)
def update_accesses_lhs(self, lhs):
if isinstance(lhs, Field.Access):
fai = self.FieldAndIndex(lhs.field, lhs.index)
if self.check_double_write_condition and lhs.offsets in self.field_writes[fai]:
raise ValueError(f"Field {lhs.field.name} is written twice at the same location")
self.field_writes[fai].add(lhs.offsets)
if self.check_double_write_condition and len(self.field_writes[fai]) > 1:
raise ValueError(
f"Field {lhs.field.name} is written at two different locations")
if fai in self.field_reads:
reads = tuple(self.field_reads[fai])
if len(reads) > 1 or lhs.offsets != reads[0]:
if self.check_independence_condition:
raise ValueError(f"Field {lhs.field.name} is written at different location than it was read. "
f"This means the resulting kernel would not be thread safe")
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
if lhs in self.scopes.free_parameters:
raise ValueError(f"Symbol {lhs.name} is written, after it has been read")
self.scopes.define_symbol(lhs)
def update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition:
fai = self.FieldAndIndex(rhs.field, rhs.index)
writes = self.field_writes[fai]
self.field_reads[fai].add(rhs.offsets)
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError(f"Violation of loop independence condition. Field "
f"{rhs.field} is read at {rhs.offsets} and written at {write_offset}")
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
import ast
import inspect
import textwrap
from typing import Callable, Union, List, Dict, Tuple
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
from pystencils.config import CreateKernelConfig
__all__ = ['kernel']
__all__ = ['kernel', 'kernel_config']
def kernel(func, **kwargs):
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
def _kernel(func: Callable[..., None], **kwargs) -> Tuple[List[Assignment], str]:
"""
Convenient function for kernel decorator to prevent code duplication
Args:
func: decorated function
**kwargs: kwargs for the function
Returns:
assignments, function_name
"""
source = inspect.getsource(func)
source = textwrap.dedent(source)
......@@ -51,9 +42,76 @@ def kernel(func, **kwargs):
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
return assignments, func.__name__
def kernel(func: Callable[..., None], **kwargs) -> List[Assignment]:
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
Args:
func: decorated function
**kwargs: kwargs for the function
Examples:
>>> import pystencils as ps
>>> @kernel
... def my_kernel(s):
... f, g = ps.fields('f, g: [2D]')
... s.neighbors @= f[0,1] + f[1,0]
... g[0,0] @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
"""
assignments, _ = _kernel(func, **kwargs)
return assignments
def kernel_config(config: CreateKernelConfig, **kwargs) -> Callable[..., Dict]:
"""Decorator to simplify generation of pystencils Assignments, which takes a configuration
and updates the function name accordingly.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
in the result list. Furthermore, the meaning of the ternary inline 'if-else' changes meaning to denote a
sympy Piecewise.
The decorated function may not receive any arguments, with exception to an argument called 's' that specifies
a SymbolCreator()
Args:
config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
Returns:
decorator with config
Examples:
>>> import pystencils as ps
>>> kernel_configuration = ps.CreateKernelConfig()
>>> @kernel_config(kernel_configuration)
... def my_kernel(s):
... src, dst = ps.fields('src, dst: [2D]')
... s.neighbors @= src[0, 1] + src[1, 0]
... dst[0, 0] @= s.neighbors + src[0, 0] if src[0, 0] > 0 else 0
>>> f, g = ps.fields('src, dst: [2D]')
>>> assert my_kernel['assignments'][0].rhs == f[0, 1] + f[1, 0]
"""
def decorator(func: Callable[..., None]) -> Union[List[Assignment], Dict]:
"""
Args:
func: decorated function
Returns:
Dict for unpacking into create_kernel
"""
assignments, func_name = _kernel(func, **kwargs)
config.function_name = func_name
return {'assignments': assignments, 'config': config}
return decorator
# noinspection PyMethodMayBeStatic
class KernelFunctionRewrite(ast.NodeTransformer):
......
"""
Light-weight wrapper around a compiled kernel
"""
import pystencils
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
"""
Light-weight wrapper around a compiled kernel.
Can be called while still providing access to underlying AST.
"""
def __init__(self, kernel, parameters, ast_node: pystencils.astnodes.KernelFunction):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
......@@ -16,4 +19,4 @@ class KernelWrapper:
@property
def code(self):
return str(pystencils.show_code(self.ast))
return pystencils.get_code_str(self.ast)
import itertools
import warnings
from typing import Union, List
import sympy as sp
from pystencils.config import CreateKernelConfig
from pystencils.assignment import Assignment, AddAugmentedAssignment
from pystencils.astnodes import Node, Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.vectorization import vectorize
from pystencils.enums import Target, Backend
from pystencils.field import Field, FieldType
from pystencils.node_collection import NodeCollection
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.kernel_contrains_check import KernelConstraintsCheck
from pystencils.simplificationfactory import create_simplification_strategy
from pystencils.stencil import direction_string_to_offset, inverse_direction_string
from pystencils.transformations import (
loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
def create_kernel(assignments: Union[Assignment, List[Assignment],
AddAugmentedAssignment, List[AddAugmentedAssignment],
AssignmentCollection, List[Node], NodeCollection],
*,
config: CreateKernelConfig = None, **kwargs):
"""
Creates abstract syntax tree (AST) of kernel, using a list of update equations.
This function forms the general API and delegates the kernel creation to others depending on the CreateKernelConfig.
Args:
assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection`
config: CreateKernelConfig which includes the needed configuration
kwargs: Arguments for updating the config
Returns:
abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
can be compiled with through its 'compile()' member
Example:
>>> import pystencils as ps
>>> import numpy as np
>>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
>>> kernel_ast = ps.create_kernel(assignment, config=ps.CreateKernelConfig(cpu_openmp=True))
>>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5])
>>> kernel(d=d_arr, s=np.ones([5, 5]))
>>> d_arr
array([[0., 0., 0., 0., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]])
"""
# ---- Updating configuration from kwargs
if not config:
config = CreateKernelConfig(**kwargs)
else:
for k, v in kwargs.items():
if not hasattr(config, k):
raise KeyError(f'{v} is not a valid kwarg. Please look in CreateKernelConfig for valid settings')
setattr(config, k, v)
# ---- Normalizing parameters
if isinstance(assignments, (Assignment, AddAugmentedAssignment)):
assignments = [assignments]
assert assignments, "Assignments must not be empty!"
if isinstance(assignments, list):
assignments = NodeCollection(assignments)
elif isinstance(assignments, AssignmentCollection):
# TODO Markus check and doku
# --- applying first default simplifications
try:
if config.default_assignment_simplifications:
simplification = create_simplification_strategy()
assignments = simplification(assignments)
except Exception as e:
warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
f"AssignmentCollection due to the following problem :{e}")
simplification_hints = assignments.simplification_hints
assignments = NodeCollection.from_assignment_collection(assignments)
assignments.simplification_hints = simplification_hints
if config.index_fields:
return create_indexed_kernel(assignments, config=config)
else:
return create_domain_kernel(assignments, config=config)
def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
"""
Creates abstract syntax tree (AST) of kernel, using a NodeCollection.
Note that `create_domain_kernel` is a lower level function which shoul be accessed by not providing `index_fields`
to create_kernel
Args:
assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
config: CreateKernelConfig which includes the needed configuration
Returns:
abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
can be compiled with through its 'compile()' member
Example:
>>> import pystencils as ps
>>> import numpy as np
>>> from pystencils.kernelcreation import create_domain_kernel
>>> from pystencils.node_collection import NodeCollection
>>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
>>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True)
>>> kernel_ast = create_domain_kernel(NodeCollection([assignment]), config=kernel_config)
>>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5])
>>> kernel(d=d_arr, s=np.ones([5, 5]))
>>> d_arr
array([[0., 0., 0., 0., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]])
"""
# --- eval
assignments.evaluate_terms()
# FUTURE WORK from here we shouldn't NEED sympy
# --- check constrains
check = KernelConstraintsCheck(check_independence_condition=not config.skip_independence_check,
check_double_write_condition=not config.allow_double_writes)
check.visit(assignments)
assignments.bound_fields = check.fields_written
assignments.rhs_fields = check.fields_read
# ---- Creating ast
ast = None
if config.target == Target.CPU:
if config.backend == Backend.C:
from pystencils.cpu import add_openmp, create_kernel
ast = create_kernel(assignments, config=config)
for optimization in config.cpu_prepend_optimizations:
optimization(ast)
omp_collapse = None
if config.cpu_blocking:
omp_collapse = loop_blocking(ast, config.cpu_blocking)
if config.cpu_openmp:
add_openmp(ast, num_threads=config.cpu_openmp, collapse=omp_collapse,
assume_single_outer_loop=config.omp_single_loop)
if config.cpu_vectorize_info:
if config.cpu_vectorize_info is True:
vectorize(ast)
elif isinstance(config.cpu_vectorize_info, dict):
vectorize(ast, **config.cpu_vectorize_info)
if config.cpu_openmp and config.cpu_blocking and 'nontemporal' in config.cpu_vectorize_info and \
config.cpu_vectorize_info['nontemporal'] and 'cachelineZero' in ast.instruction_set:
# This condition is stricter than it needs to be: if blocks along the fastest axis start on a
# cache line boundary, it's okay. But we cannot determine that here.
# We don't need to disallow OpenMP collapsing because it is never applied to the inner loop.
raise ValueError("Blocking cannot be combined with cacheline-zeroing")
else:
raise ValueError("Invalid value for cpu_vectorize_info")
elif config.target == Target.GPU:
if config.backend == Backend.CUDA:
from pystencils.gpu import create_cuda_kernel
ast = create_cuda_kernel(assignments, config=config)
if not ast:
raise NotImplementedError(
f'{config.target} together with {config.backend} is not supported by `create_domain_kernel`')
if config.use_auto_for_assignments:
for a in ast.atoms(SympyAssignment):
a.use_auto = True
return ast
def create_indexed_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
"""
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
The coordinates are stored in a separated index_field, which is a one dimensional array with struct data type.
This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel, for
example boundary parameters.
Note that `create_indexed_kernel` is a lower level function which shoul be accessed by providing `index_fields`
to create_kernel
Args:
assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
config: CreateKernelConfig which includes the needed configuration
Returns:
abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
can be compiled with through its 'compile()' member
Example:
>>> import pystencils as ps
>>> from pystencils.node_collection import NodeCollection
>>> import numpy as np
>>> from pystencils.kernelcreation import create_indexed_kernel
>>>
>>> # Index field stores the indices of the cell to visit together with optional values
>>> index_arr_dtype = np.dtype([('x', np.int32), ('y', np.int32), ('val', np.double)])
>>> index_arr = np.array([(1, 1, 0.1), (2, 2, 0.2), (3, 3, 0.3)], dtype=index_arr_dtype)
>>> idx_field = ps.fields(idx=index_arr)
>>>
>>> # Additional values stored in index field can be accessed in the kernel as well
>>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0, 0], 2 * s[0, 1] + 2 * s[1, 0] + idx_field('val'))
>>> kernel_config = ps.CreateKernelConfig(index_fields=[idx_field], coordinate_names=('x', 'y'))
>>> kernel_ast = create_indexed_kernel(NodeCollection([assignment]), config=kernel_config)
>>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5])
>>> kernel(s=np.ones([5, 5]), d=d_arr, idx=index_arr)
>>> d_arr
array([[0. , 0. , 0. , 0. , 0. ],
[0. , 4.1, 0. , 0. , 0. ],
[0. , 0. , 4.2, 0. , 0. ],
[0. , 0. , 0. , 4.3, 0. ],
[0. , 0. , 0. , 0. , 0. ]])
"""
# --- eval
assignments.evaluate_terms()
# FUTURE WORK from here we shouldn't NEED sympy
# --- check constrains
check = KernelConstraintsCheck(check_independence_condition=not config.skip_independence_check,
check_double_write_condition=not config.allow_double_writes)
check.visit(assignments)
assignments.bound_fields = check.fields_written
assignments.rhs_fields = check.fields_read
ast = None
if config.target == Target.CPU and config.backend == Backend.C:
from pystencils.cpu import add_openmp, create_indexed_kernel
ast = create_indexed_kernel(assignments, config=config)
if config.cpu_openmp:
add_openmp(ast, num_threads=config.cpu_openmp)
elif config.target == Target.GPU:
if config.backend == Backend.CUDA:
from pystencils.gpu import created_indexed_cuda_kernel
ast = created_indexed_cuda_kernel(assignments, config=config)
if not ast:
raise NotImplementedError(f'Indexed kernels are not yet supported for {config.target} with {config.backend}')
return ast
def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs):
"""Kernel that updates a staggered field.
.. image:: /img/staggered_grid.svg
For a staggered field, the first index coordinate defines the location of the staggered value.
Further index coordinates can be used to store vectors/tensors at each point.
Args:
assignments: a sequence of assignments or an AssignmentCollection.
Assignments to staggered field are processed specially, while subexpressions and assignments to
regular fields are passed through to `create_kernel`. Multiple different staggered fields can be
used, but they all need to use the same stencil (i.e. the same number of staggered points) and
shape.
target: 'CPU' or 'GPU'
gpu_exclusive_conditions: disable the use of multiple conditionals inside the loop. The outer layers are then
handled in an else branch.
kwargs: passed directly to create_kernel, iteration_slice and ghost_layers parameters are not allowed
Returns:
AST, see `create_kernel`
"""
# TODO: Add doku like in the other kernels
if 'ghost_layers' in kwargs:
assert kwargs['ghost_layers'] is None
del kwargs['ghost_layers']
if 'iteration_slice' in kwargs:
assert kwargs['iteration_slice'] is None
del kwargs['iteration_slice']
if 'omp_single_loop' in kwargs:
assert kwargs['omp_single_loop'] is False
del kwargs['omp_single_loop']
if isinstance(assignments, AssignmentCollection):
subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
if not hasattr(a, 'lhs')
or type(a.lhs) is not Field.Access
or not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments.main_assignments if hasattr(a, 'lhs')
and type(a.lhs) is Field.Access
and FieldType.is_staggered(a.lhs.field)]
else:
subexpressions = [a for a in assignments if not hasattr(a, 'lhs')
or type(a.lhs) is not Field.Access
or not FieldType.is_staggered(a.lhs.field)]
assignments = [a for a in assignments if hasattr(a, 'lhs')
and type(a.lhs) is Field.Access
and FieldType.is_staggered(a.lhs.field)]
if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
raise ValueError("All assignments need to be made to staggered fields with the same stencil")
if len(set([a.lhs.field.shape for a in assignments])) != 1:
raise ValueError("All assignments need to be made to staggered fields with the same shape")
staggered_field = assignments[0].lhs.field
stencil = staggered_field.staggered_stencil
dim = staggered_field.spatial_dimensions
shape = staggered_field.shape
counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
final_assignments = []
# find out whether any of the ghost layers is not needed
common_exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for direction in stencil:
exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for elementary_direction in direction:
exclusions.remove(inverse_direction_string(elementary_direction))
common_exclusions.intersection_update(exclusions)
ghost_layers = [[0, 0] for d in range(dim)]
for direction in common_exclusions:
direction = direction_string_to_offset(direction)
for d, s in enumerate(direction):
if s == 1:
ghost_layers[d][1] = 1
elif s == -1:
ghost_layers[d][0] = 1
def condition(direction):
"""exclude those staggered points that correspond to fluxes between ghost cells"""
exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
for elementary_direction in direction:
exclusions.remove(inverse_direction_string(elementary_direction))
conditions = []
for e in exclusions:
if e in common_exclusions:
continue
offset = direction_string_to_offset(e)
for i, o in enumerate(offset):
if o == 1:
conditions.append(counters[i] < shape[i] - 1)
elif o == -1:
conditions.append(counters[i] > 0)
return sp.And(*conditions)
if gpu_exclusive_conditions:
outer_assignment = None
conditions = {direction: condition(direction) for direction in stencil}
for num_conditions in range(len(stencil)):
for combination in itertools.combinations(conditions.values(), num_conditions):
for assignment in assignments:
direction = stencil[assignment.lhs.index[0]]
if conditions[direction] in combination:
assignment = SympyAssignment(assignment.lhs, assignment.rhs)
outer_assignment = Conditional(sp.And(*combination), Block([assignment]), outer_assignment)
inner_assignment = []
for assignment in assignments:
inner_assignment.append(SympyAssignment(assignment.lhs, assignment.rhs))
last_conditional = Conditional(sp.And(*[condition(d) for d in stencil]),
Block(inner_assignment), outer_assignment)
final_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
[SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
[last_conditional]
config = CreateKernelConfig(target=target, ghost_layers=ghost_layers, omp_single_loop=False, **kwargs)
ast = create_kernel(final_assignments, config=config)
return ast
for assignment in assignments:
direction = stencil[assignment.lhs.index[0]]
sp_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
[SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
[SympyAssignment(assignment.lhs, assignment.rhs)]
last_conditional = Conditional(condition(direction), Block(sp_assignments))
final_assignments.append(last_conditional)
remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers])
prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional),
move_constants_before_loop]
if 'cpu_prepend_optimizations' in kwargs:
prepend_optimizations += kwargs['cpu_prepend_optimizations']
del kwargs['cpu_prepend_optimizations']
config = CreateKernelConfig(ghost_layers=ghost_layers, target=target, omp_single_loop=False,
cpu_prepend_optimizations=prepend_optimizations, **kwargs)
ast = create_kernel(final_assignments, config=config)
return ast
from typing import Any, Dict, List, Union, Optional, Set
import sympy
import sympy as sp
from sympy.codegen.rewriting import ReplaceOptim, optimize
from pystencils.assignment import Assignment, AddAugmentedAssignment
import pystencils.astnodes as ast
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.functions import DivFunc
from pystencils.simp import AssignmentCollection
from pystencils.typing import FieldPointerSymbol
class NodeCollection:
def __init__(self, assignments: List[Union[ast.Node, Assignment]],
simplification_hints: Optional[Dict[str, Any]] = None,
bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None):
def visit(obj):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
if isinstance(obj, Assignment):
if isinstance(obj.lhs, FieldPointerSymbol):
return ast.SympyAssignment(obj.lhs, obj.rhs, is_const=obj.lhs.dtype.const)
return ast.SympyAssignment(obj.lhs, obj.rhs)
elif isinstance(obj, AddAugmentedAssignment):
return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs)
elif isinstance(obj, ast.SympyAssignment):
return obj
elif isinstance(obj, ast.Conditional):
true_block = visit(obj.true_block)
false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(obj.condition_expr, true_block=true_block, false_block=false_block)
elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args])
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in the List of Assignments " + str(type(obj)))
self.all_assignments = visit(assignments)
self.simplification_hints = simplification_hints if simplification_hints else {}
self.bound_fields = bound_fields if bound_fields else {}
self.rhs_fields = rhs_fields if rhs_fields else {}
@staticmethod
def from_assignment_collection(assignment_collection: AssignmentCollection):
return NodeCollection(assignments=assignment_collection.all_assignments,
simplification_hints=assignment_collection.simplification_hints,
bound_fields=assignment_collection.bound_fields,
rhs_fields=assignment_collection.rhs_fields)
def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf()
)
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
(DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
)
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
def visitor(node):
if isinstance(node, CustomCodeNode):
return node
elif isinstance(node, ast.Block):
return node.func([visitor(child) for child in node.args])
elif isinstance(node, ast.SympyAssignment):
new_lhs = visitor(node.lhs)
new_rhs = visitor(node.rhs)
return node.func(new_lhs, new_rhs, node.is_const, node.use_auto)
elif isinstance(node, ast.Node):
return node.func(*[visitor(child) for child in node.args])
elif isinstance(node, sympy.Basic):
return optimize(node, sympy_optimisations)
else:
raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
......@@ -34,7 +34,7 @@ def to_placeholder_function(expr, name):
"""
symbols = list(expr.atoms(sp.Symbol))
symbols.sort(key=lambda e: e.name)
derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols]
derivative_symbols = [sp.Symbol(f"_d{name}_d{s.name}") for s in symbols]
derivatives = [sp.diff(expr, s) for s in symbols]
assignments = [Assignment(sp.Symbol(name), expr)]
......
File moved
import copy
import numpy as np
import sympy as sp
from pystencils.typing import TypedSymbol, CastFunc
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.sympyextensions import fast_subs
class RNGBase(CustomCodeNode):
id = 0
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
if keys is None:
keys = (0,) * self._num_keys
if offsets is None:
offsets = (0,) * dim
if len(keys) != self._num_keys:
raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
if len(offsets) != dim:
raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")
coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
if dim < 3:
coordinates.append(0)
self._args = sp.sympify([time_step, *coordinates, *keys])
self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
for i in range(self._num_vars))
symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self.headers = [f'"{self._name.split("_")[0]}_rand.h"']
RNGBase.id += 1
@property
def args(self):
return self._args
def fast_subs(self, subs_dict, skip):
rng = copy.deepcopy(self)
rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args]
return rng
def get_code(self, dialect, vector_instruction_set, print_arg):
code = "\n"
for r in self.result_symbols:
if vector_instruction_set and not self.args[1].atoms(CastFunc):
# this vector RNG has become scalar through substitution
code += f"{r.dtype} {r.name};\n"
else:
code += f"{vector_instruction_set[r.dtype.c_name] if vector_instruction_set else r.dtype} " + \
f"{r.name};\n"
args = [print_arg(a) for a in self.args] + ['' + r.name for r in self.result_symbols]
code += (self._name + "(" + ", ".join(args) + ");\n")
return code
def __repr__(self):
return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
def _hashable_content(self):
return (self._name, *self.result_symbols, *self.args)
def __eq__(self, other):
return type(self) is type(other) and self._hashable_content() == other._hashable_content()
def __hash__(self):
return hash(self._hashable_content())
class PhiloxTwoDoubles(RNGBase):
_name = "philox_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 2
class PhiloxFourFloats(RNGBase):
_name = "philox_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 2
class AESNITwoDoubles(RNGBase):
_name = "aesni_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 4
class AESNIFourFloats(RNGBase):
_name = "aesni_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 4
def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles,
time_step=TypedSymbol("time_step", np.uint32), offsets=None):
"""Return a symbol generator for random numbers
Args:
assignment_list: the subexpressions member of an AssignmentCollection, into which helper variables assignments
will be inserted
dim: 2 or 3 for two or three spatial dimensions
seed: an integer or TypedSymbol(..., np.uint32) to seed the random number generator. If you create multiple
symbol generators, please pass them different seeds so you don't get the same stream of random numbers!
rng_node: which random number generator to use (PhiloxTwoDoubles, PhiloxFourFloats, AESNITwoDoubles,
AESNIFourFloats).
time_step: TypedSymbol(..., np.uint32) that indicates the number of the current time step
offsets: tuple of offsets (constant integers or TypedSymbol(..., np.uint32)) that give the global coordinates
of the local origin
"""
counter = 0
while True:
keys = (counter, seed) + (0,) * (rng_node._num_keys - 2)
node = rng_node(dim, keys=keys, time_step=time_step, offsets=offsets)
inserted = False
for symbol in node.result_symbols:
if not inserted:
assignment_list.insert(0, node)
inserted = True
yield symbol
counter += 1
import socket
import time
from types import MappingProxyType
from typing import Dict, Iterator, Sequence
import blitzdb
import six
from blitzdb.backends.file.backend import serializer_classes
from blitzdb.backends.file.utils import JsonEncoder
from pystencils.cpu.cpujit import get_compiler_config
from pystencils import CreateKernelConfig, Target, Backend, Field
import json
import sympy as sp
from pystencils.typing import BasicType
class PystencilsJsonEncoder(JsonEncoder):
def default(self, obj):
if isinstance(obj, CreateKernelConfig):
return obj.__dict__
if isinstance(obj, (sp.Float, sp.Rational)):
return float(obj)
if isinstance(obj, sp.Integer):
return int(obj)
if isinstance(obj, (BasicType, MappingProxyType)):
return str(obj)
if isinstance(obj, (Target, Backend, sp.Symbol)):
return obj.name
if isinstance(obj, Field):
return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \
f"strides = {obj.strides})"
return JsonEncoder.default(self, obj)
class PystencilsJsonSerializer(object):
@classmethod
def serialize(cls, data):
if six.PY3:
if isinstance(data, bytes):
return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
@classmethod
def deserialize(cls, data):
if six.PY3:
return json.loads(data.decode('utf-8'))
else:
return json.loads(data.decode('utf-8'))
class Database:
......@@ -46,7 +96,7 @@ class Database:
class SimulationResult(blitzdb.Document):
pass
def __init__(self, file: str) -> None:
def __init__(self, file: str, serializer_info: tuple = None) -> None:
if file.startswith("mongo://"):
from pymongo import MongoClient
db_name = file[len("mongo://"):]
......@@ -57,6 +107,10 @@ class Database:
self.backend.autocommit = True
if serializer_info:
serializer_classes.update({serializer_info[0]: serializer_info[1]})
self.backend.load_config({'serializer_class': serializer_info[0]}, True)
def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
"""Stores a simulation result in the database.
......@@ -120,7 +174,7 @@ class Database:
Returns:
pandas data frame
"""
from pandas.io.json import json_normalize
from pandas import json_normalize
query_result = self.filter_params(parameter_query)
attributes = [e.attributes for e in query_result]
......@@ -146,10 +200,15 @@ class Database:
'cpuCompilerConfig': get_compiler_config(),
}
try:
from git import Repo, InvalidGitRepositoryError
from git import Repo
except ImportError:
return result
try:
from git import InvalidGitRepositoryError
repo = Repo(search_parent_directories=True)
result['git_hash'] = str(repo.head.commit)
except (ImportError, InvalidGitRepositoryError):
except InvalidGitRepositoryError:
pass
return result
......
......@@ -9,6 +9,7 @@ from time import sleep
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from pystencils.runhelper import Database
from pystencils.runhelper.db import PystencilsJsonSerializer
from pystencils.utils import DotDict
ParameterDict = Dict[str, Any]
......@@ -54,10 +55,11 @@ class ParameterStudy:
Run = namedtuple("Run", ['parameter_dict', 'weight'])
def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (),
database_connector: str = './db') -> None:
database_connector: str = './db',
serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None:
self.runs = list(runs)
self.run_function = run_function
self.db = Database(database_connector)
self.db = Database(database_connector, serializer_info)
def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None:
"""Schedule a dictionary of parameters to run in this parameter study.
......@@ -215,7 +217,7 @@ class ParameterStudy:
def log_message(self, fmt, *args):
return
print("Listening to connections on {}:{}. Scenarios to simulate: {}".format(ip, port, len(filtered_runs)))
print(f"Listening to connections on {ip}:{port}. Scenarios to simulate: {len(filtered_runs)}")
server = HTTPServer((ip, port), ParameterStudyServer)
while len(ParameterStudyServer.currently_running) > 0 or len(ParameterStudyServer.runs) > 0:
server.handle_request()
......@@ -241,7 +243,7 @@ class ParameterStudy:
from urllib.error import URLError
import time
parameter_update = {} if parameter_update is None else parameter_update
url = "http://{}:{}".format(server, port)
url = f"http://{server}:{port}"
client_name = client_name.format(hostname=socket.gethostname(), pid=os.getpid())
start_time = time.time()
while True:
......@@ -265,7 +267,7 @@ class ParameterStudy:
'client_name': client_name}
urlopen(url + '/result', data=json.dumps(answer).encode())
except URLError:
print("Cannot connect to server {} retrying in 5 seconds...".format(url))
print(f"Cannot connect to server {url} retrying in 5 seconds...")
sleep(5)
def run_from_command_line(self, argv: Optional[Sequence[str]] = None) -> None:
......
......@@ -2,8 +2,7 @@ import numpy as np
import sympy as sp
import pystencils as ps
import pystencils.jupyter
from pystencils.jupyter import make_imshow_animation, display_animation, set_display_mode
import pystencils.plot as plt
import pystencils.sympy_gmpy_bug_workaround
__all__ = ['sp', 'np', 'ps', 'plt']
__all__ = ['sp', 'np', 'ps', 'plt', 'make_imshow_animation', 'display_animation', 'set_display_mode']
from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
apply_on_all_subexpressions, apply_to_all_assignments,
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .subexpression_insertion import (
insert_aliases, insert_zeros, insert_constants,
insert_constant_additions, insert_constant_multiples,
insert_squares, insert_symbol_times_minus_one)
from .simplificationstrategy import SimplificationStrategy
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
'add_subexpressions_for_field_reads']
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
'insert_aliases', 'insert_zeros', 'insert_constants',
'insert_constant_additions', 'insert_constant_multiples',
'insert_squares', 'insert_symbol_times_minus_one']
import itertools
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs
......@@ -18,13 +19,14 @@ class AssignmentCollection:
Additionally a dictionary of simplification hints is stored, which are set by the functions that create
assignment collections to transport information to the simplification system.
Attributes:
main_assignments: list of assignments
subexpressions: list of assignments defining subexpressions used in main equations
simplification_hints: dict that is used to annotate the assignment collection with hints that are
Args:
main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
assignment is a field access. Thus the generated equations write on arrays.
subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning.
subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added
subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection
"""
......@@ -32,9 +34,13 @@ class AssignmentCollection:
# ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = {},
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()]
......@@ -42,6 +48,11 @@ class AssignmentCollection:
subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments
self.subexpressions = subexpressions
......@@ -50,8 +61,11 @@ class AssignmentCollection:
self.simplification_hints = simplification_hints
ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen()
self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
else:
self.subexpression_symbol_generator = subexpression_symbol_generator
......@@ -95,22 +109,45 @@ class AssignmentCollection:
"""Subexpression and main equations as a single list."""
return self.subexpressions + self.main_assignments
@property
def rhs_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which occur on the rhs of any assignment."""
rhs_symbols = set()
for eq in self.all_assignments:
if isinstance(eq, Assignment):
rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
rhs_symbols.update(eq.undefined_symbols)
return rhs_symbols
@property
def free_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols = set()
for eq in self.all_assignments:
free_symbols.update(eq.rhs.atoms(sp.Symbol))
return free_symbols - self.bound_symbols
return self.rhs_symbols - self.bound_symbols
@property
def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set([eq.lhs for eq in self.all_assignments])
assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \
bound_symbols_set = set(
[assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
])
return bound_symbols_set
@property
def rhs_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
@property
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
......@@ -124,13 +161,18 @@ class AssignmentCollection:
@property
def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
return set([assignment.lhs for assignment in self.main_assignments])
lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
if isinstance(assignment, pystencils.astnodes.Node)]))
@property
def operation_count(self):
"""See :func:`count_operations` """
return count_operations(self.all_assignments, only_type=None)
def atoms(self, *args):
return set().union(*[a.atoms(*args) for a in self.all_assignments])
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols.
......@@ -182,6 +224,7 @@ class AssignmentCollection:
return {s: func(*args, **kwargs) for s, func in lambdas.items()}
return f
# ---------------------------- Creating new modified collections ---------------------------------------------------
def copy(self,
......@@ -235,7 +278,7 @@ class AssignmentCollection:
own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols"
"Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {}
......@@ -243,12 +286,13 @@ class AssignmentCollection:
processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols:
if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs
else:
......@@ -271,9 +315,9 @@ class AssignmentCollection:
if eq.lhs in symbols_to_extract:
new_assignments.append(eq)
new_sub_expr = [eq for eq in self.subexpressions
new_sub_expr = [eq for eq in self.all_assignments
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return AssignmentCollection(new_assignments, new_sub_expr)
return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments."""
......@@ -296,8 +340,10 @@ class AssignmentCollection:
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection':
def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0:
return self.copy()
......@@ -306,7 +352,7 @@ class AssignmentCollection:
kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {}
kept_subexpressions = self.subexpressions[0]
kept_subexpressions.append(self.subexpressions[0])
else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
......@@ -325,6 +371,7 @@ class AssignmentCollection:
def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations):
no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">'
......@@ -345,15 +392,15 @@ class AssignmentCollection:
return result
def __repr__(self):
return "Assignment Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])
return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
def __str__(self):
result = "Subexpressions:\n"
for eq in self.subexpressions:
result += "\t{eq}\n".format(eq=eq)
result += f"\t{eq}\n"
result += "Main Assignments:\n"
for eq in self.main_assignments:
result += "\t{eq}\n".format(eq=eq)
result += f"\t{eq}\n"
return result
def __iter__(self):
......@@ -403,18 +450,24 @@ class AssignmentCollection:
def __eq__(self, other):
return set(self.all_assignments) == set(other.all_assignments)
def __bool__(self):
return bool(self.all_assignments)
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"):
self._ctr = 0
def __init__(self, symbol="xi", dtype=None, ctr=0):
self._ctr = ctr
self._symbol = symbol
self._dtype = dtype
def __iter__(self):
return self
def __next__(self):
name = "{}_{}".format(self._symbol, self._ctr)
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name)