From 9abd0e4a9cf79c7927b09736e6bed554a64ee67d Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 15 Mar 2019 09:13:43 +0100
Subject: [PATCH] use SIMD inverse sqrt approximation when available

---
 backends/cbackend.py              |  5 ++++-
 backends/simd_instruction_sets.py | 10 ++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 7dd75d0a6..67f255253 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -330,7 +330,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         elif expr.func == fast_sqrt:
             return "({})".format(self._print(sp.sqrt(expr.args[0])))
         elif expr.func == fast_inv_sqrt:
-            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
+            if self.instruction_set['rsqrt']:
+                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
+            else:
+                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
         return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
 
     def _print_And(self, expr):
diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py
index 56ee997ae..2d88352bb 100644
--- a/backends/simd_instruction_sets.py
+++ b/backends/simd_instruction_sets.py
@@ -93,6 +93,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
     result['bool'] = "__m%dd" % (bit_width,)
 
     result['headers'] = headers[instruction_set]
+
+    if instruction_set == 'avx512' and data_type == 'double':
+        result['rsqrt'] = "_mm512_rsqrt14_pd({0})"
+    elif instruction_set == 'avx512' and data_type == 'float':
+        result['rsqrt'] = "_mm512_rsqrt14_ps({0})"
+    elif instruction_set == 'avx' and data_type == 'float':
+        result['rsqrt'] = "_mm256_rsqrt_ps({0})"
+    else:
+        result['rsqrt'] = None
+
     return result
 
 
-- 
GitLab