diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py
index 55af4a6bda3ee1417883bd1dfbbbdc0df78051f6..b70b3ce72fbc494261bb059c6948879669d3ec9f 100644
--- a/pystencils/backends/opencl_backend.py
+++ b/pystencils/backends/opencl_backend.py
@@ -2,6 +2,7 @@ import pystencils.data_types
 from pystencils.astnodes import Node
 from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
 from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
+from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 
 
 def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
@@ -67,5 +68,29 @@ class OpenClSympyPrinter(CudaSympyPrinter):
     def _print_TextureAccess(self, node):
         raise NotImplementedError()
 
-    # Avoid usage of CUDA intrinsics
-    _print_Function = CustomSympyPrinter._print_Function
+    # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
+    # since built-in math functions are generic.
+    # In CUDA, you have to differentiate between `sin` and `sinf`
+    _print_math_func = CustomSympyPrinter._print_math_func
+    _print_Pow = CustomSympyPrinter._print_Pow
+
+    def _print_Function(self, expr):
+        """Print ``expr``, mapping fast-approximation functions to OpenCL ``native_*`` calls.
+
+        ``fast_division``, ``fast_sqrt`` and ``fast_inv_sqrt`` instances are printed as
+        ``native_divide``, ``native_sqrt`` and ``native_rsqrt`` applied to the printed
+        arguments. Every other expression is passed to
+        ``CustomSympyPrinter._print_Function``, which is called explicitly rather than
+        through ``super()``, bypassing the ``CudaSympyPrinter`` base class.
+        """
+        # Checked in order; the first matching type decides the template.
+        native_templates = (
+            (fast_division, "native_divide(%s, %s)"),
+            (fast_sqrt, "native_sqrt(%s)"),
+            (fast_inv_sqrt, "native_rsqrt(%s)"),
+        )
+        for approximation, template in native_templates:
+            if isinstance(expr, approximation):
+                printed_args = tuple(self._print(arg) for arg in expr.args)
+                return template % printed_args
+        return CustomSympyPrinter._print_Function(self, expr)