diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h index 4646b17c15d7d0f8fa28e2b01b160a632d080f4f..648c9f7aea2e7fdcd05400af9f7ade95a682979e 100644 --- a/pystencils/include/aesni_rand.h +++ b/pystencils/include/aesni_rand.h @@ -218,6 +218,30 @@ QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k) { return x; } +QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + __m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0)); + __m128i ctr1v = _mm_set1_epi32(ctr1); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + aesni_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + __m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0)); + __m128i ctr1v = _mm_set1_epi32(ctr1); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + aesni_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + template<bool high> QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y) { @@ -350,6 +374,30 @@ QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); } + +QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + __m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0)); + __m256i ctr1v = _mm256_set1_epi32(ctr1); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + aesni_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + __m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0)); + __m256i ctr1v = _mm256_set1_epi32(ctr1); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + aesni_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} #endif @@ -511,5 +559,29 @@ QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); } + +QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4) +{ + __m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)); + __m512i ctr1v = _mm512_set1_epi32(ctr1); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + __m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)); + __m512i ctr1v = _mm512_set1_epi32(ctr1); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} #endif diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index 5df48abb0b41779d36c0b68dc3c6fb9fbfdca6db..1f6fed52ab2b79a5165e3fa538ca501504ba32c3 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -241,6 +241,30 @@ QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); } + +QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + __m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0)); + __m128i ctr1v = _mm_set1_epi32(ctr1); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + __m128i ctr0v = _mm_add_epi32(_mm_set1_epi32(ctr0), _mm_set_epi32(3,2,1,0)); + __m128i ctr1v = _mm_set1_epi32(ctr1); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} #endif #ifdef __AVX2__ @@ -369,6 +393,30 @@ QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); } + +QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + __m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0)); + __m256i ctr1v = _mm256_set1_epi32(ctr1); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + __m256i ctr0v = _mm256_add_epi32(_mm256_set1_epi32(ctr0), _mm256_set_epi32(7,6,5,4,3,2,1,0)); + __m256i ctr1v = _mm256_set1_epi32(ctr1); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} #endif #ifdef __AVX512F__ @@ -481,6 +529,30 @@ QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); } + +QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4) +{ + __m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)); + __m512i ctr1v = _mm512_set1_epi32(ctr1); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + __m512i ctr0v = _mm512_add_epi32(_mm512_set1_epi32(ctr0), _mm512_set_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)); + __m512i ctr1v = _mm512_set1_epi32(ctr1); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1v, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} #endif #endif diff --git a/pystencils/rng.py b/pystencils/rng.py index 5bc91b5678de9ec9e10bc43953604c4978eeaae7..b1168724aa497153a130a6fbcdd7beaef2cefcef 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -6,16 +6,23 @@ from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import CustomCodeNode -def _get_rng_template(name, data_type, num_vars): +def _data_type_to_str(data_type): if data_type is np.float32: - c_type = "float" + return "float" elif data_type is np.float64: - c_type = "double" + return "double" + elif type(data_type) is str: + return data_type + raise ValueError("%s is not a supported data type" % (data_type, )) + + +def _get_rng_template(name, data_type, num_vars): + c_type = _data_type_to_str(data_type) template = "\n" for i in range(num_vars): - template += "{{result_symbols[{}].dtype}} {{result_symbols[{}].name}};\n".format(i, i) - template += ("{}_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \ - .format(name, c_type, num_vars, *tuple(range(num_vars))) + template += "{} {{result_symbols[{}].name}};\n".format(c_type, i, i) + template += ("{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \ + .format(name, *tuple(range(num_vars))) return template @@ -23,7 +30,7 @@ def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)] + [0] * (3 - dim) + list(keys) - if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): + if dialect == 'cuda' or dialect == 'c': return template.format(parameters=', '.join(str(p) for p in parameters), result_symbols=result_symbols) else: @@ -44,7 +51,7 @@ class RNGBase(CustomCodeNode): super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self._time_step = time_step self._offsets = offsets - self.headers = ['"{}_rand.h"'.format(self._name)] + self.headers = ['"{}_rand.h"'.format(self._name.split('_')[0])] self.keys = tuple(keys) self._args = sp.sympify((dim, time_step, keys)) self._dim = dim @@ -65,7 +72,11 @@ class RNGBase(CustomCodeNode): return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them def get_code(self, dialect, vector_instruction_set): - template = _get_rng_template(self._name, self._data_type, self._num_vars) + if vector_instruction_set: + template = _get_rng_template(self._name, vector_instruction_set[_data_type_to_str(self._data_type)], + self._num_vars) + else: + template = _get_rng_template(self._name, self._data_type, self._num_vars) return _get_rng_code(template, dialect, vector_instruction_set, self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) @@ -74,28 +85,28 @@ class RNGBase(CustomCodeNode): class PhiloxTwoDoubles(RNGBase): - _name = "philox" + _name = "philox_double2" _data_type = np.float64 _num_vars = 2 _num_keys = 2 class PhiloxFourFloats(RNGBase): - _name = "philox" + _name = "philox_float4" _data_type = np.float32 _num_vars = 4 _num_keys = 2 class AESNITwoDoubles(RNGBase): - _name = "aesni" + _name = "aesni_double2" _data_type = np.float64 _num_vars = 2 _num_keys = 4 class AESNIFourFloats(RNGBase): - _name = "aesni" + _name = "aesni_float4" _data_type = np.float32 _num_vars = 4 _num_keys = 4