From ecb1614f82e237fca0b0b85c7f6350a80beca104 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 15 Aug 2019 14:06:44 +0200
Subject: [PATCH] Implement fast approximations for OpenCL: fast_division,
 fast_inv_sqrt, fast_sqrt

---
 pystencils/backends/opencl_backend.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py
index 55af4a6..b70b3ce 100644
--- a/pystencils/backends/opencl_backend.py
+++ b/pystencils/backends/opencl_backend.py
@@ -2,6 +2,7 @@ import pystencils.data_types
 from pystencils.astnodes import Node
 from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
 from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
+from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 
 
 def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
@@ -67,5 +68,17 @@ class OpenClSympyPrinter(CudaSympyPrinter):
     def _print_TextureAccess(self, node):
         raise NotImplementedError()
 
-    # Avoid usage of CUDA intrinsics
-    _print_Function = CustomSympyPrinter._print_Function
+    # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
+    # since built-in math functions are generic.
+    # In CUDA, you have to differentiate between `sin` and `sinf`
+    _print_math_func = CustomSympyPrinter._print_math_func
+    _print_Pow = CustomSympyPrinter._print_Pow
+
+    def _print_Function(self, expr):
+        if isinstance(expr, fast_division):
+            return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args)
+        elif isinstance(expr, fast_sqrt):
+            return "native_sqrt(%s)" % tuple(self._print(a) for a in expr.args)
+        elif isinstance(expr, fast_inv_sqrt):
+            return "native_rsqrt(%s)" % tuple(self._print(a) for a in expr.args)
+        return CustomSympyPrinter._print_Function(self, expr)
-- 
GitLab