diff --git a/backends/cbackend.py b/backends/cbackend.py index 7dd75d0a6f37b7051b65b40d5facc0f7d7cfad44..67f25525381a2ea00f7dfbabea03aa017699de3b 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -330,7 +330,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif expr.func == fast_sqrt: return "({})".format(self._print(sp.sqrt(expr.args[0]))) elif expr.func == fast_inv_sqrt: - return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) + if self.instruction_set['rsqrt']: + return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) + else: + return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) def _print_And(self, expr): diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index 56ee997aeac4379e774016bd34a013907b708d27..2d88352bb0833338e69f83e7b1cc3a4accde6d49 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -93,6 +93,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): result['bool'] = "__m%dd" % (bit_width,) result['headers'] = headers[instruction_set] + + if instruction_set == 'avx512' and data_type == 'double': + result['rsqrt'] = "_mm512_rsqrt14_pd({0})" + elif instruction_set == 'avx512' and data_type == 'float': + result['rsqrt'] = "_mm512_rsqrt14_ps({0})" + elif instruction_set == 'avx' and data_type == 'float': + result['rsqrt'] = "_mm256_rsqrt_ps({0})" + else: + result['rsqrt'] = None + return result