From ecb1614f82e237fca0b0b85c7f6350a80beca104 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 15 Aug 2019 14:06:44 +0200 Subject: [PATCH] Implement fast approximations for OpenCL: fast_division, fast_inv_sqrt, fast_sqrt --- pystencils/backends/opencl_backend.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py index 55af4a6..b70b3ce 100644 --- a/pystencils/backends/opencl_backend.py +++ b/pystencils/backends/opencl_backend.py @@ -2,6 +2,7 @@ import pystencils.data_types from pystencils.astnodes import Node from pystencils.backends.cbackend import CustomSympyPrinter, generate_c from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter +from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt def generate_opencl(astnode: Node, signature_only: bool = False) -> str: @@ -67,5 +68,17 @@ class OpenClSympyPrinter(CudaSympyPrinter): def _print_TextureAccess(self, node): raise NotImplementedError() - # Avoid usage of CUDA intrinsics - _print_Function = CustomSympyPrinter._print_Function + # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter + # since built-in math functions are generic. + # In CUDA, you have to differentiate between `sin` and `sinf` + _print_math_func = CustomSympyPrinter._print_math_func + _print_Pow = CustomSympyPrinter._print_Pow + + def _print_Function(self, expr): + if isinstance(expr, fast_division): + return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args) + elif isinstance(expr, fast_sqrt): + return "native_sqrt(%s)" % tuple(self._print(a) for a in expr.args) + elif isinstance(expr, fast_inv_sqrt): + return "native_rsqrt(%s)" % tuple(self._print(a) for a in expr.args) + return CustomSympyPrinter._print_Function(self, expr) -- GitLab