Commit 9abd0e4a authored by Martin Bauer's avatar Martin Bauer
Browse files

use SIMD inverse sqrt approximation when available

parent 5ad56d04
......@@ -330,7 +330,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif expr.func == fast_inv_sqrt:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
if self.instruction_set['rsqrt']:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_And(self, expr):
......
......@@ -93,6 +93,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['bool'] = "__m%dd" % (bit_width,)
result['headers'] = headers[instruction_set]
if instruction_set == 'avx512' and data_type == 'double':
result['rsqrt'] = "_mm512_rsqrt14_pd({0})"
elif instruction_set == 'avx512' and data_type == 'float':
result['rsqrt'] = "_mm512_rsqrt14_ps({0})"
elif instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})"
else:
result['rsqrt'] = None
return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment