From faf330f83b8fdec3974b532bd957d81f271369ff Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 9 Jul 2019 14:01:08 +0200 Subject: [PATCH] Add CudaBackend, CudaSympyPrinter --- pystencils/backends/cbackend.py | 52 ++-- pystencils/backends/cuda_backend.py | 88 ++++++ pystencils/backends/cuda_known_functions.txt | 293 +++++++++++++++++++ 3 files changed, 406 insertions(+), 27 deletions(-) create mode 100644 pystencils/backends/cuda_backend.py create mode 100644 pystencils/backends/cuda_known_functions.txt diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 7c4937d..cbd75e1 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy KERNCRAFT_NO_TERNARY_MODE = False +class UnsupportedCDialect(Exception): + def __init__(self): + super(UnsupportedCDialect, self).__init__() + + def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str: """Prints an abstract syntax tree node as C or CUDA code. 
@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str ast_node.global_variables.update(d.symbols_defined) else: ast_node.global_variables = d.symbols_defined - printer = CBackend(signature_only=signature_only, - vector_instruction_set=ast_node.instruction_set, - dialect=dialect) + + if dialect == 'c': + printer = CBackend(signature_only=signature_only, + vector_instruction_set=ast_node.instruction_set) + elif dialect == 'cuda': + from pystencils.backends.cuda_backend import CudaBackend + printer = CudaBackend(signature_only=signature_only) + else: + raise UnsupportedCDialect code = printer(ast_node) if not signature_only and isinstance(ast_node, KernelFunction): code = "\n" + code @@ -141,9 +152,9 @@ class CBackend: def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect) + self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) else: - self.sympy_printer = CustomSympyPrinter(dialect) + self.sympy_printer = CustomSympyPrinter() else: self.sympy_printer = sympy_printer @@ -164,12 +175,12 @@ class CBackend: method_name = "_print_" + cls.__name__ if hasattr(self, method_name): return getattr(self, method_name)(node) - raise NotImplementedError("CBackend does not support node of type " + str(type(node))) + raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + str(type(node))) def _print_KernelFunction(self, node): function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()] launch_bounds = "" - if self._dialect == 'cuda': + if self._dialect == 'cuda': max_threads = node.indexing.max_threads_per_block() if max_threads: launch_bounds = "__launch_bounds__({}) ".format(max_threads) @@ -241,10 +252,7 @@ class CBackend: return "free(%s - %d);" % 
(self.sympy_printer.doprint(node.symbol.name), node.offset(align)) def _print_SkipIteration(self, _): - if self._dialect == 'cuda': - return "return;" - else: - return "continue;" + return "continue;" def _print_CustomCodeNode(self, node): return node.get_code(self._dialect, self._vector_instruction_set) @@ -292,10 +300,9 @@ class CBackend: # noinspection PyPep8Naming class CustomSympyPrinter(CCodePrinter): - def __init__(self, dialect): + def __init__(self): super(CustomSympyPrinter, self).__init__() self._float_type = create_type("float32") - self._dialect = dialect if 'Min' in self.known_functions: del self.known_functions['Min'] if 'Max' in self.known_functions: @@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter): else: return "((%s)(%s))" % (data_type, self._print(arg)) elif isinstance(expr, fast_division): - if self._dialect == "cuda": - return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) - else: - return "({})".format(self._print(expr.args[0] / expr.args[1])) + return "({})".format(self._print(expr.args[0] / expr.args[1])) elif isinstance(expr, fast_sqrt): - if self._dialect == "cuda": - return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) - else: - return "({})".format(self._print(sp.sqrt(expr.args[0]))) + return "({})".format(self._print(sp.sqrt(expr.args[0]))) elif isinstance(expr, vec_any) or isinstance(expr, vec_all): return self._print(expr.args[0]) elif isinstance(expr, fast_inv_sqrt): - if self._dialect == "cuda": - return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) - else: - return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) + return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) elif expr.func in infix_functions: return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) elif expr.func == int_power_of_2: @@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter): 
SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) - def __init__(self, instruction_set, dialect): - super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect) + def __init__(self, instruction_set): + super(VectorizedCustomSympyPrinter, self).__init__() self.instruction_set = instruction_set def _scalarFallback(self, func_name, expr, *args, **kwargs): diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py new file mode 100644 index 0000000..e9a7816 --- /dev/null +++ b/pystencils/backends/cuda_backend.py @@ -0,0 +1,88 @@ + +from os.path import dirname, join + +from pystencils.astnodes import Node +from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter, + generate_c) +from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, + fast_sqrt) + +CUDA_KNOWN_FUNCTIONS = None +with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: + lines = f.readlines() + CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l.strip()} + + +def generate_cuda(astnode: Node, signature_only: bool = False) -> str: + """Prints an abstract syntax tree node as CUDA code. 
+ + Args: + astnode: + signature_only: + + Returns: + C-like code for the ast node and its descendants + """ + return generate_c(astnode, signature_only, dialect='cuda') + + +class CudaBackend(CBackend): + + def __init__(self, sympy_printer=None, + signature_only=False): + if not sympy_printer: + sympy_printer = CudaSympyPrinter() + + super().__init__(sympy_printer, signature_only, dialect='cuda') + + def _print_SharedMemoryAllocation(self, node): + code = "__shared__ {dtype} {name}[{num_elements}];" + return code.format(dtype=node.symbol.dtype, + name=self.sympy_printer.doprint(node.symbol.name), + num_elements='*'.join([str(s) for s in node.shared_mem.shape])) + + def _print_ThreadBlockSynchronization(self, node): + code = "__syncthreads();" + return code + + def _print_TextureDeclaration(self, node): + code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % ( + str(node.texture.field.dtype), + node.texture.field.spatial_dimensions, + node.texture + ) + return code + + def _print_SkipIteration(self, _): + return "return;" + + +class CudaSympyPrinter(CustomSympyPrinter): + + def __init__(self): + super(CudaSympyPrinter, self).__init__() + self.known_functions = CUDA_KNOWN_FUNCTIONS + + def _print_TextureAccess(self, node): + + if node.texture.cubic_bspline_interpolation: + template = "cubicTex%iDSimple<%s>(%s, %s)" + else: + template = "tex%iD<%s>(%s, %s)" + + code = template % ( + node.texture.field.spatial_dimensions, + str(node.texture.field.dtype), + str(node.texture), + ', '.join(self._print(o) for o in node.offsets) + ) + return code + + def _print_Function(self, expr): + if isinstance(expr, fast_division): + return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) + elif isinstance(expr, fast_sqrt): + return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) + elif isinstance(expr, fast_inv_sqrt): + return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args) + return super()._print_Function(expr) diff --git 
a/pystencils/backends/cuda_known_functions.txt b/pystencils/backends/cuda_known_functions.txt new file mode 100644 index 0000000..42cf554 --- /dev/null +++ b/pystencils/backends/cuda_known_functions.txt @@ -0,0 +1,293 @@ +__prof_trigger +printf + +__syncthreads +__syncthreads_count +__syncthreads_and +__syncthreads_or +__syncwarp +__threadfence +__threadfence_block +__threadfence_system + +atomicAdd +atomicSub +atomicExch +atomicMin +atomicMax +atomicInc +atomicDec +atomicAnd +atomicOr +atomicXor +atomicCAS + +__all_sync +__any_sync +__ballot_sync +__active_mask + +__shfl_sync +__shfl_up_sync +__shfl_down_sync +__shfl_xor_sync + +__match_any_sync +__match_all_sync + +__isGlobal +__isShared +__isConstant +__isLocal + +tex1Dfetch +tex1D +tex2D +tex3D + +rsqrtf +cbrtf +rcbrtf +hypotf +rhypotf +norm3df +rnorm3df +norm4df +rnorm4df +normf +rnormf +expf +exp2f +exp10f +expm1f +logf +log2f +log10f +log1pf +sinf +cosf +tanf +sincosf +sinpif +cospif +sincospif +asinf +acosf +atanf +atan2f +sinhf +coshf +tanhf +asinhf +acoshf +atanhf +powf +erff +erfcf +erfinvf +erfcinvf +erfcxf +normcdff +normcdfinvf +lgammaf +tgammaf +fmaf +frexpf +ldexpf +scalbnf +scalblnf +logbf +ilogbf +j0f +j1f +jnf +y0f +y1f +ynf +cyl_bessel_i0f +cyl_bessel_i1f +fmodf +remainderf +remquof +modff +fdimf +truncf +roundf +rintf +nearbyintf +ceilf +floorf +lrintf +lroundf +llrintf +llroundf + +sqrt +rsqrt +cbrt +rcbrt +hypot +rhypot +norm3d +rnorm3d +norm4d +rnorm4d +norm +rnorm +exp +exp2 +exp10 +expm1 +log +log2 +log10 +log1p +sin +cos +tan +sincos +sinpi +cospi +sincospi +asin +acos +atan +atan2 +sinh +cosh +tanh +asinh +acosh +atanh +pow +erf +erfc +erfinv +erfcinv +erfcx +normcdf +normcdfinv +lgamma +tgamma +fma +frexp +ldexp +scalbn +scalbln +logb +ilogb +j0 +j1 +jn +y0 +y1 +yn +cyl_bessel_i0 +cyl_bessel_i1 +fmod +remainder +remquo +modf +fdim +trunc +round +rint +nearbyint +ceil +floor +lrint +lround +llrint +llround + +__fdividef +__sinf +__cosf +__tanf +__sincosf +__logf +__log2f +__log10f +__expf 
+__exp10f +__powf + +__fadd_rn +__fsub_rn +__fmul_rn +__fmaf_rn +__frcp_rn +__fsqrt_rn +__frsqrt_rn +__fdiv_rn + +__fadd_rz +__fsub_rz +__fmul_rz +__fmaf_rz +__frcp_rz +__fsqrt_rz +__frsqrt_rz +__fdiv_rz + +__fadd_ru +__fsub_ru +__fmul_ru +__fmaf_ru +__frcp_ru +__fsqrt_ru +__frsqrt_ru +__fdiv_ru + +__fadd_rd +__fsub_rd +__fmul_rd +__fmaf_rd +__frcp_rd +__fsqrt_rd +__frsqrt_rd +__fdiv_rd + +__fdividef +__expf +__exp10f +__logf +__log2f +__log10f +__sinf +__cosf +__sincosf +__tanf +__powf + +__dadd_rn +__dsub_rn +__dmul_rn +__fma_rn +__ddiv_rn +__drcp_rn +__dsqrt_rn + +__dadd_rz +__dsub_rz +__dmul_rz +__fma_rz +__ddiv_rz +__drcp_rz +__dsqrt_rz + +__dadd_ru +__dsub_ru +__dmul_ru +__fma_ru +__ddiv_ru +__drcp_ru +__dsqrt_ru + +__dadd_rd +__dsub_rd +__dmul_rd +__fma_rd +__ddiv_rd +__drcp_rd +__dsqrt_rd -- GitLab