From faf330f83b8fdec3974b532bd957d81f271369ff Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 9 Jul 2019 14:01:08 +0200
Subject: [PATCH] Add CudaBackend, CudaSympyPrinter

---
 pystencils/backends/cbackend.py              |  52 ++--
 pystencils/backends/cuda_backend.py          |  88 ++++++
 pystencils/backends/cuda_known_functions.txt | 293 +++++++++++++++++++
 3 files changed, 406 insertions(+), 27 deletions(-)
 create mode 100644 pystencils/backends/cuda_backend.py
 create mode 100644 pystencils/backends/cuda_known_functions.txt

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 7c4937d..cbd75e1 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
 KERNCRAFT_NO_TERNARY_MODE = False
 
 
+class UnsupportedCDialect(Exception):
+    def __init__(self):
+        super(UnsupportedCDialect, self).__init__()
+
+
 def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
     """Prints an abstract syntax tree node as C or CUDA code.
 
@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
             ast_node.global_variables.update(d.symbols_defined)
         else:
             ast_node.global_variables = d.symbols_defined
-    printer = CBackend(signature_only=signature_only,
-                       vector_instruction_set=ast_node.instruction_set,
-                       dialect=dialect)
+
+    if dialect == 'c':
+        printer = CBackend(signature_only=signature_only,
+                           vector_instruction_set=ast_node.instruction_set)
+    elif dialect == 'cuda':
+        from pystencils.backends.cuda_backend import CudaBackend
+        printer = CudaBackend(signature_only=signature_only)
+    else:
+        raise UnsupportedCDialect
     code = printer(ast_node)
     if not signature_only and isinstance(ast_node, KernelFunction):
         code = "\n" + code
@@ -141,9 +152,9 @@ class CBackend:
     def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
         if sympy_printer is None:
             if vector_instruction_set is not None:
-                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
+                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
             else:
-                self.sympy_printer = CustomSympyPrinter(dialect)
+                self.sympy_printer = CustomSympyPrinter()
         else:
             self.sympy_printer = sympy_printer
 
@@ -164,12 +175,12 @@ class CBackend:
             method_name = "_print_" + cls.__name__
             if hasattr(self, method_name):
                 return getattr(self, method_name)(node)
-        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
+        raise NotImplementedError(self.__class__ + " does not support node of type " + str(type(node)))
 
     def _print_KernelFunction(self, node):
         function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
         launch_bounds = ""
-        if self._dialect == 'cuda':
+        if self.__class__ == 'cuda':
             max_threads = node.indexing.max_threads_per_block()
             if max_threads:
                 launch_bounds = "__launch_bounds__({}) ".format(max_threads)
@@ -241,10 +252,7 @@ class CBackend:
         return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
 
     def _print_SkipIteration(self, _):
-        if self._dialect == 'cuda':
-            return "return;"
-        else:
-            return "continue;"
+        return "continue;"
 
     def _print_CustomCodeNode(self, node):
         return node.get_code(self._dialect, self._vector_instruction_set)
@@ -292,10 +300,9 @@ class CBackend:
 # noinspection PyPep8Naming
 class CustomSympyPrinter(CCodePrinter):
 
-    def __init__(self, dialect):
+    def __init__(self):
         super(CustomSympyPrinter, self).__init__()
         self._float_type = create_type("float32")
-        self._dialect = dialect
         if 'Min' in self.known_functions:
             del self.known_functions['Min']
         if 'Max' in self.known_functions:
@@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter):
             else:
                 return "((%s)(%s))" % (data_type, self._print(arg))
         elif isinstance(expr, fast_division):
-            if self._dialect == "cuda":
-                return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
-            else:
-                return "({})".format(self._print(expr.args[0] / expr.args[1]))
+            return "({})".format(self._print(expr.args[0] / expr.args[1]))
         elif isinstance(expr, fast_sqrt):
-            if self._dialect == "cuda":
-                return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
-            else:
-                return "({})".format(self._print(sp.sqrt(expr.args[0])))
+            return "({})".format(self._print(sp.sqrt(expr.args[0])))
         elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
             return self._print(expr.args[0])
         elif isinstance(expr, fast_inv_sqrt):
-            if self._dialect == "cuda":
-                return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
-            else:
-                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
+            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
         elif expr.func in infix_functions:
             return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
         elif expr.func == int_power_of_2:
@@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter):
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
     SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
 
-    def __init__(self, instruction_set, dialect):
-        super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
+    def __init__(self, instruction_set):
+        super(VectorizedCustomSympyPrinter, self).__init__()
         self.instruction_set = instruction_set
 
     def _scalarFallback(self, func_name, expr, *args, **kwargs):
diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
new file mode 100644
index 0000000..e9a7816
--- /dev/null
+++ b/pystencils/backends/cuda_backend.py
@@ -0,0 +1,88 @@
+
+from os.path import dirname, join
+
+from pystencils.astnodes import Node
+from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter,
+                                          generate_c)
+from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
+                                           fast_sqrt)
+
+CUDA_KNOWN_FUNCTIONS = None
+with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
+    lines = f.readlines()
+    CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}
+
+
+def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
+    """Prints an abstract syntax tree node as CUDA code.
+
+    Args:
+        ast_node:
+        signature_only:
+
+    Returns:
+        C-like code for the ast node and its descendants
+    """
+    return generate_c(astnode, signature_only, dialect='cuda')
+
+
+class CudaBackend(CBackend):
+
+    def __init__(self, sympy_printer=None,
+                 signature_only=False):
+        if not sympy_printer:
+            sympy_printer = CudaSympyPrinter()
+
+        super().__init__(sympy_printer, signature_only, dialect='cuda')
+
+    def _print_SharedMemoryAllocation(self, node):
+        code = "__shared__ {dtype} {name}[{num_elements}];"
+        return code.format(dtype=node.symbol.dtype,
+                           name=self.sympy_printer.doprint(node.symbol.name),
+                           num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
+
+    def _print_ThreadBlockSynchronization(self, node):
+        code = "__synchtreads();"
+        return code
+
+    def _print_TextureDeclaration(self, node):
+        code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
+            str(node.texture.field.dtype),
+            node.texture.field.spatial_dimensions,
+            node.texture
+        )
+        return code
+
+    def _print_SkipIteration(self, _):
+        return "return;"
+
+
+class CudaSympyPrinter(CustomSympyPrinter):
+
+    def __init__(self):
+        super(CudaSympyPrinter, self).__init__()
+        self.known_functions = CUDA_KNOWN_FUNCTIONS
+
+    def _print_TextureAccess(self, node):
+
+        if node.texture.cubic_bspline_interpolation:
+            template = "cubicTex%iDSimple<%s>(%s, %s)"
+        else:
+            template = "tex%iD<%s>(%s, %s)"
+
+        code = template % (
+            node.texture.field.spatial_dimensions,
+            str(node.texture.field.dtype),
+            str(node.texture),
+            ', '.join(self._print(o) for o in node.offsets)
+        )
+        return code
+
+    def _print_Function(self, expr):
+        if isinstance(expr, fast_division):
+            return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
+        elif isinstance(expr, fast_sqrt):
+            return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
+        elif isinstance(expr, fast_inv_sqrt):
+            return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
+        return super()._print_Function(expr)
diff --git a/pystencils/backends/cuda_known_functions.txt b/pystencils/backends/cuda_known_functions.txt
new file mode 100644
index 0000000..42cf554
--- /dev/null
+++ b/pystencils/backends/cuda_known_functions.txt
@@ -0,0 +1,293 @@
+__prof_trigger
+printf
+
+__syncthreads
+__syncthreads_count
+__syncthreads_and
+__syncthreads_or
+__syncwarp
+__threadfence
+__threadfence_block
+__threadfence_system
+
+atomicAdd
+atomicSub
+atomicExch
+atomicMin
+atomicMax
+atomicInc
+atomicDec
+atomicAnd
+atomicOr
+atomicXor
+atomicCAS
+
+__all_sync
+__any_sync
+__ballot_sync
+__active_mask
+
+__shfl_sync
+__shfl_up_sync
+__shfl_down_sync
+__shfl_xor_sync
+
+__match_any_sync
+__match_all_sync
+
+__isGlobal
+__isShared
+__isConstant
+__isLocal
+
+tex1Dfetch
+tex1D
+tex2D
+tex3D
+
+rsqrtf
+cbrtf
+rcbrtf
+hypotf
+rhypotf
+norm3df
+rnorm3df
+norm4df
+rnorm4df
+normf
+rnormf
+expf
+exp2f
+exp10f
+expm1f
+logf
+log2f
+log10f
+log1pf
+sinf
+cosf
+tanf
+sincosf
+sinpif
+cospif
+sincospif
+asinf
+acosf
+atanf
+atan2f
+sinhf
+coshf
+tanhf
+asinhf
+acoshf
+atanhf
+powf
+erff
+erfcf
+erfinvf
+erfcinvf
+erfcxf
+normcdff
+normcdfinvf
+lgammaf
+tgammaf
+fmaf
+frexpf
+ldexpf
+scalbnf
+scalblnf
+logbf
+ilogbf
+j0f
+j1f
+jnf
+y0f
+y1f
+ynf
+cyl_bessel_i0f
+cyl_bessel_i1f
+fmodf
+remainderf
+remquof
+modff
+fdimf
+truncf
+roundf
+rintf
+nearbyintf
+ceilf
+floorf
+lrintf
+lroundf
+llrintf
+llroundf
+
+sqrt
+rsqrt
+cbrt
+rcbrt
+hypot
+rhypot
+norm3d
+rnorm3d
+norm4d
+rnorm4d
+norm
+rnorm
+exp
+exp2
+exp10
+expm1
+log
+log2
+log10
+log1p
+sin
+cos
+tan
+sincos
+sinpi
+cospi
+sincospi
+asin
+acos
+atan
+atan2
+sinh
+cosh
+tanh
+asinh
+acosh
+atanh
+pow
+erf
+erfc
+erfinv
+erfcinv
+erfcx
+normcdf
+normcdfinv
+lgamma
+tgamma
+fma
+frexp
+ldexp
+scalbn
+scalbln
+logb
+ilogb
+j0
+j1
+jn
+y0
+y1
+yn
+cyl_bessel_i0
+cyl_bessel_i1
+fmod
+remainder
+remquo
+mod
+fdim
+trunc
+round
+rint
+nearbyint
+ceil
+floor
+lrint
+lround
+llrint
+llround
+
+__fdividef
+__sinf
+__cosf
+__tanf
+__sincosf
+__logf
+__log2f
+__log10f
+__expf
+__exp10f
+__powf
+
+__fadd_rn
+__fsub_rn
+__fmul_rn
+__fmaf_rn
+__frcp_rn
+__fsqrt_rn
+__frsqrt_rn
+__fdiv_rn
+
+__fadd_rz
+__fsub_rz
+__fmul_rz
+__fmaf_rz
+__frcp_rz
+__fsqrt_rz
+__frsqrt_rz
+__fdiv_rz
+
+__fadd_ru
+__fsub_ru
+__fmul_ru
+__fmaf_ru
+__frcp_ru
+__fsqrt_ru
+__frsqrt_ru
+__fdiv_ru
+
+__fadd_rd
+__fsub_rd
+__fmul_rd
+__fmaf_rd
+__frcp_rd
+__fsqrt_rd
+__frsqrt_rd
+__fdiv_rd
+
+__fdividef
+__expf
+__exp10f
+__logf
+__log2f
+__log10f
+__sinf
+__cosf
+__sincosf
+__tanf
+__powf
+
+__dadd_rn
+__dsub_rn
+__dmul_rn
+__fma_rn
+__ddiv_rn
+__drcp_rn
+__dsqrt_rn
+
+__dadd_rz
+__dsub_rz
+__dmul_rz
+__fma_rz
+__ddiv_rz
+__drcp_rz
+__dsqrt_rz
+
+__dadd_ru
+__dsub_ru
+__dmul_ru
+__fma_ru
+__ddiv_ru
+__drcp_ru
+__dsqrt_ru
+
+__dadd_rd
+__dsub_rd
+__dmul_rd
+__fma_rd
+__ddiv_rd
+__drcp_rd
+__dsqrt_rd
-- 
GitLab