Skip to content
Snippets Groups Projects
Commit ecb1614f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement fast approximations for OpenCL: fast_division, fast_inv_sqrt, fast_sqrt

parent bc556567
Branches
Tags
No related merge requests found
......@@ -2,6 +2,7 @@ import pystencils.data_types
from pystencils.astnodes import Node
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
......@@ -67,5 +68,17 @@ class OpenClSympyPrinter(CudaSympyPrinter):
def _print_TextureAccess(self, node):
raise NotImplementedError()
# Avoid usage of CUDA intrinsics
_print_Function = CustomSympyPrinter._print_Function
# For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
# since built-in math functions are generic.
# In CUDA, you have to differentiate between `sin` and `sinf`
_print_math_func = CustomSympyPrinter._print_math_func
_print_Pow = CustomSympyPrinter._print_Pow
def _print_Function(self, expr):
if isinstance(expr, fast_division):
return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt):
return "native_sqrt(%s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_inv_sqrt):
return "native_rsqrt(%s)" % tuple(self._print(a) for a in expr.args)
return CustomSympyPrinter._print_Function(self, expr)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment