Skip to content
Snippets Groups Projects
Commit faf330f8 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add CudaBackend, CudaSympyPrinter

parent 9b9a4b54
Branches
Tags
1 merge request!9Add CudaBackend, CudaSympyPrinter
......@@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
KERNCRAFT_NO_TERNARY_MODE = False
class UnsupportedCDialect(Exception):
def __init__(self):
super(UnsupportedCDialect, self).__init__()
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
......@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
ast_node.global_variables.update(d.symbols_defined)
else:
ast_node.global_variables = d.symbols_defined
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set,
dialect=dialect)
if dialect == 'c':
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set)
elif dialect == 'cuda':
from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only)
else:
raise UnsupportedCDialect
code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code
......@@ -141,9 +152,9 @@ class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
else:
self.sympy_printer = CustomSympyPrinter(dialect)
self.sympy_printer = CustomSympyPrinter()
else:
self.sympy_printer = sympy_printer
......@@ -164,12 +175,12 @@ class CBackend:
method_name = "_print_" + cls.__name__
if hasattr(self, method_name):
return getattr(self, method_name)(node)
raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
raise NotImplementedError(self.__class__ + " does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
launch_bounds = ""
if self._dialect == 'cuda':
if self.__class__ == 'cuda':
max_threads = node.indexing.max_threads_per_block()
if max_threads:
launch_bounds = "__launch_bounds__({}) ".format(max_threads)
......@@ -241,10 +252,7 @@ class CBackend:
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
def _print_SkipIteration(self, _):
if self._dialect == 'cuda':
return "return;"
else:
return "continue;"
return "continue;"
def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set)
......@@ -292,10 +300,9 @@ class CBackend:
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
def __init__(self, dialect):
def __init__(self):
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
self._dialect = dialect
if 'Min' in self.known_functions:
del self.known_functions['Min']
if 'Max' in self.known_functions:
......@@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter):
else:
return "((%s)(%s))" % (data_type, self._print(arg))
elif isinstance(expr, fast_division):
if self._dialect == "cuda":
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(expr.args[0] / expr.args[1]))
return "({})".format(self._print(expr.args[0] / expr.args[1]))
elif isinstance(expr, fast_sqrt):
if self._dialect == "cuda":
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
if self._dialect == "cuda":
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
elif expr.func == int_power_of_2:
......@@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set, dialect):
super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
def __init__(self, instruction_set):
super(VectorizedCustomSympyPrinter, self).__init__()
self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs):
......
from os.path import dirname, join
from pystencils.astnodes import Node
from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter,
generate_c)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
CUDA_KNOWN_FUNCTIONS = None
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}
def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as CUDA code.
Args:
ast_node:
signature_only:
Returns:
C-like code for the ast node and its descendants
"""
return generate_c(astnode, signature_only, dialect='cuda')
class CudaBackend(CBackend):
def __init__(self, sympy_printer=None,
signature_only=False):
if not sympy_printer:
sympy_printer = CudaSympyPrinter()
super().__init__(sympy_printer, signature_only, dialect='cuda')
def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];"
return code.format(dtype=node.symbol.dtype,
name=self.sympy_printer.doprint(node.symbol.name),
num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
def _print_ThreadBlockSynchronization(self, node):
code = "__synchtreads();"
return code
def _print_TextureDeclaration(self, node):
code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype),
node.texture.field.spatial_dimensions,
node.texture
)
return code
def _print_SkipIteration(self, _):
return "return;"
class CudaSympyPrinter(CustomSympyPrinter):
def __init__(self):
super(CudaSympyPrinter, self).__init__()
self.known_functions = CUDA_KNOWN_FUNCTIONS
def _print_TextureAccess(self, node):
if node.texture.cubic_bspline_interpolation:
template = "cubicTex%iDSimple<%s>(%s, %s)"
else:
template = "tex%iD<%s>(%s, %s)"
code = template % (
node.texture.field.spatial_dimensions,
str(node.texture.field.dtype),
str(node.texture),
', '.join(self._print(o) for o in node.offsets)
)
return code
def _print_Function(self, expr):
if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_inv_sqrt):
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
return super()._print_Function(expr)
__prof_trigger
printf
__syncthreads
__syncthreads_count
__syncthreads_and
__syncthreads_or
__syncwarp
__threadfence
__threadfence_block
__threadfence_system
atomicAdd
atomicSub
atomicExch
atomicMin
atomicMax
atomicInc
atomicDec
atomicAnd
atomicOr
atomicXor
atomicCAS
__all_sync
__any_sync
__ballot_sync
__active_mask
__shfl_sync
__shfl_up_sync
__shfl_down_sync
__shfl_xor_sync
__match_any_sync
__match_all_sync
__isGlobal
__isShared
__isConstant
__isLocal
tex1Dfetch
tex1D
tex2D
tex3D
rsqrtf
cbrtf
rcbrtf
hypotf
rhypotf
norm3df
rnorm3df
norm4df
rnorm4df
normf
rnormf
expf
exp2f
exp10f
expm1f
logf
log2f
log10f
log1pf
sinf
cosf
tanf
sincosf
sinpif
cospif
sincospif
asinf
acosf
atanf
atan2f
sinhf
coshf
tanhf
asinhf
acoshf
atanhf
powf
erff
erfcf
erfinvf
erfcinvf
erfcxf
normcdff
normcdfinvf
lgammaf
tgammaf
fmaf
frexpf
ldexpf
scalbnf
scalblnf
logbf
ilogbf
j0f
j1f
jnf
y0f
y1f
ynf
cyl_bessel_i0f
cyl_bessel_i1f
fmodf
remainderf
remquof
modff
fdimf
truncf
roundf
rintf
nearbyintf
ceilf
floorf
lrintf
lroundf
llrintf
llroundf
sqrt
rsqrt
cbrt
rcbrt
hypot
rhypot
norm3d
rnorm3d
norm4d
rnorm4d
norm
rnorm
exp
exp2
exp10
expm1
log
log2
log10
log1p
sin
cos
tan
sincos
sinpi
cospi
sincospi
asin
acos
atan
atan2
sinh
cosh
tanh
asinh
acosh
atanh
pow
erf
erfc
erfinv
erfcinv
erfcx
normcdf
normcdfinv
lgamma
tgamma
fma
frexp
ldexp
scalbn
scalbln
logb
ilogb
j0
j1
jn
y0
y1
yn
cyl_bessel_i0
cyl_bessel_i1
fmod
remainder
remquo
mod
fdim
trunc
round
rint
nearbyint
ceil
floor
lrint
lround
llrint
llround
__fdividef
__sinf
__cosf
__tanf
__sincosf
__logf
__log2f
__log10f
__expf
__exp10f
__powf
__fadd_rn
__fsub_rn
__fmul_rn
__fmaf_rn
__frcp_rn
__fsqrt_rn
__frsqrt_rn
__fdiv_rn
__fadd_rz
__fsub_rz
__fmul_rz
__fmaf_rz
__frcp_rz
__fsqrt_rz
__frsqrt_rz
__fdiv_rz
__fadd_ru
__fsub_ru
__fmul_ru
__fmaf_ru
__frcp_ru
__fsqrt_ru
__frsqrt_ru
__fdiv_ru
__fadd_rd
__fsub_rd
__fmul_rd
__fmaf_rd
__frcp_rd
__fsqrt_rd
__frsqrt_rd
__fdiv_rd
__fdividef
__expf
__exp10f
__logf
__log2f
__log10f
__sinf
__cosf
__sincosf
__tanf
__powf
__dadd_rn
__dsub_rn
__dmul_rn
__fma_rn
__ddiv_rn
__drcp_rn
__dsqrt_rn
__dadd_rz
__dsub_rz
__dmul_rz
__fma_rz
__ddiv_rz
__drcp_rz
__dsqrt_rz
__dadd_ru
__dsub_ru
__dmul_ru
__fma_ru
__ddiv_ru
__drcp_ru
__dsqrt_ru
__dadd_rd
__dsub_rd
__dmul_rd
__fma_rd
__ddiv_rd
__drcp_rd
__dsqrt_rd
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment