Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Bindgen
pystencils
Commits
faf330f8
Commit
faf330f8
authored
Jul 09, 2019
by
Stephan Seitz
Browse files
Add CudaBackend, CudaSympyPrinter
parent
9b9a4b54
Changes
3
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
faf330f8
...
...
@@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
KERNCRAFT_NO_TERNARY_MODE
=
False
class
UnsupportedCDialect
(
Exception
):
def
__init__
(
self
):
super
(
UnsupportedCDialect
,
self
).
__init__
()
def
generate_c
(
ast_node
:
Node
,
signature_only
:
bool
=
False
,
dialect
=
'c'
)
->
str
:
"""Prints an abstract syntax tree node as C or CUDA code.
...
...
@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
ast_node
.
global_variables
.
update
(
d
.
symbols_defined
)
else
:
ast_node
.
global_variables
=
d
.
symbols_defined
printer
=
CBackend
(
signature_only
=
signature_only
,
vector_instruction_set
=
ast_node
.
instruction_set
,
dialect
=
dialect
)
if
dialect
==
'c'
:
printer
=
CBackend
(
signature_only
=
signature_only
,
vector_instruction_set
=
ast_node
.
instruction_set
)
elif
dialect
==
'cuda'
:
from
pystencils.backends.cuda_backend
import
CudaBackend
printer
=
CudaBackend
(
signature_only
=
signature_only
)
else
:
raise
UnsupportedCDialect
code
=
printer
(
ast_node
)
if
not
signature_only
and
isinstance
(
ast_node
,
KernelFunction
):
code
=
"
\n
"
+
code
...
...
@@ -141,9 +152,9 @@ class CBackend:
def
__init__
(
self
,
sympy_printer
=
None
,
signature_only
=
False
,
vector_instruction_set
=
None
,
dialect
=
'c'
):
if
sympy_printer
is
None
:
if
vector_instruction_set
is
not
None
:
self
.
sympy_printer
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
,
dialect
)
self
.
sympy_printer
=
VectorizedCustomSympyPrinter
(
vector_instruction_set
)
else
:
self
.
sympy_printer
=
CustomSympyPrinter
(
dialect
)
self
.
sympy_printer
=
CustomSympyPrinter
()
else
:
self
.
sympy_printer
=
sympy_printer
...
...
@@ -164,12 +175,12 @@ class CBackend:
method_name
=
"_print_"
+
cls
.
__name__
if
hasattr
(
self
,
method_name
):
return
getattr
(
self
,
method_name
)(
node
)
raise
NotImplementedError
(
"CBackend
does not support node of type "
+
str
(
type
(
node
)))
raise
NotImplementedError
(
self
.
__class__
+
"
does not support node of type "
+
str
(
type
(
node
)))
def
_print_KernelFunction
(
self
,
node
):
function_arguments
=
[
"%s %s"
%
(
str
(
s
.
symbol
.
dtype
),
s
.
symbol
.
name
)
for
s
in
node
.
get_parameters
()]
launch_bounds
=
""
if
self
.
_
dialect
==
'cuda'
:
if
self
.
_
_class__
==
'cuda'
:
max_threads
=
node
.
indexing
.
max_threads_per_block
()
if
max_threads
:
launch_bounds
=
"__launch_bounds__({}) "
.
format
(
max_threads
)
...
...
@@ -241,10 +252,7 @@ class CBackend:
return
"free(%s - %d);"
%
(
self
.
sympy_printer
.
doprint
(
node
.
symbol
.
name
),
node
.
offset
(
align
))
def
_print_SkipIteration
(
self
,
_
):
if
self
.
_dialect
==
'cuda'
:
return
"return;"
else
:
return
"continue;"
return
"continue;"
def
_print_CustomCodeNode
(
self
,
node
):
return
node
.
get_code
(
self
.
_dialect
,
self
.
_vector_instruction_set
)
...
...
@@ -292,10 +300,9 @@ class CBackend:
# noinspection PyPep8Naming
class
CustomSympyPrinter
(
CCodePrinter
):
def
__init__
(
self
,
dialect
):
def
__init__
(
self
):
super
(
CustomSympyPrinter
,
self
).
__init__
()
self
.
_float_type
=
create_type
(
"float32"
)
self
.
_dialect
=
dialect
if
'Min'
in
self
.
known_functions
:
del
self
.
known_functions
[
'Min'
]
if
'Max'
in
self
.
known_functions
:
...
...
@@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter):
else
:
return
"((%s)(%s))"
%
(
data_type
,
self
.
_print
(
arg
))
elif
isinstance
(
expr
,
fast_division
):
if
self
.
_dialect
==
"cuda"
:
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
else
:
return
"({})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
]))
return
"({})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
]))
elif
isinstance
(
expr
,
fast_sqrt
):
if
self
.
_dialect
==
"cuda"
:
return
"__fsqrt_rn(%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
else
:
return
"({})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
])))
return
"({})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
])))
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
return
self
.
_print
(
expr
.
args
[
0
])
elif
isinstance
(
expr
,
fast_inv_sqrt
):
if
self
.
_dialect
==
"cuda"
:
return
"__frsqrt_rn(%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
else
:
return
"({})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
])))
return
"({})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
])))
elif
expr
.
func
in
infix_functions
:
return
"(%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
]),
infix_functions
[
expr
.
func
],
self
.
_print
(
expr
.
args
[
1
]))
elif
expr
.
func
==
int_power_of_2
:
...
...
@@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter):
class
VectorizedCustomSympyPrinter
(
CustomSympyPrinter
):
SummandInfo
=
namedtuple
(
"SummandInfo"
,
[
'sign'
,
'term'
])
def
__init__
(
self
,
instruction_set
,
dialect
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
(
dialect
=
dialect
)
def
__init__
(
self
,
instruction_set
):
super
(
VectorizedCustomSympyPrinter
,
self
).
__init__
()
self
.
instruction_set
=
instruction_set
def
_scalarFallback
(
self
,
func_name
,
expr
,
*
args
,
**
kwargs
):
...
...
pystencils/backends/cuda_backend.py
0 → 100644
View file @
faf330f8
from
os.path
import
dirname
,
join
from
pystencils.astnodes
import
Node
from
pystencils.backends.cbackend
import
(
CBackend
,
CustomSympyPrinter
,
generate_c
)
from
pystencils.fast_approximation
import
(
fast_division
,
fast_inv_sqrt
,
fast_sqrt
)
CUDA_KNOWN_FUNCTIONS
=
None
with
open
(
join
(
dirname
(
__file__
),
'cuda_known_functions.txt'
))
as
f
:
lines
=
f
.
readlines
()
CUDA_KNOWN_FUNCTIONS
=
{
l
.
strip
():
l
.
strip
()
for
l
in
lines
if
l
}
def
generate_cuda
(
astnode
:
Node
,
signature_only
:
bool
=
False
)
->
str
:
"""Prints an abstract syntax tree node as CUDA code.
Args:
ast_node:
signature_only:
Returns:
C-like code for the ast node and its descendants
"""
return
generate_c
(
astnode
,
signature_only
,
dialect
=
'cuda'
)
class
CudaBackend
(
CBackend
):
def
__init__
(
self
,
sympy_printer
=
None
,
signature_only
=
False
):
if
not
sympy_printer
:
sympy_printer
=
CudaSympyPrinter
()
super
().
__init__
(
sympy_printer
,
signature_only
,
dialect
=
'cuda'
)
def
_print_SharedMemoryAllocation
(
self
,
node
):
code
=
"__shared__ {dtype} {name}[{num_elements}];"
return
code
.
format
(
dtype
=
node
.
symbol
.
dtype
,
name
=
self
.
sympy_printer
.
doprint
(
node
.
symbol
.
name
),
num_elements
=
'*'
.
join
([
str
(
s
)
for
s
in
node
.
shared_mem
.
shape
]))
def
_print_ThreadBlockSynchronization
(
self
,
node
):
code
=
"__synchtreads();"
return
code
def
_print_TextureDeclaration
(
self
,
node
):
code
=
"texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;"
%
(
str
(
node
.
texture
.
field
.
dtype
),
node
.
texture
.
field
.
spatial_dimensions
,
node
.
texture
)
return
code
def
_print_SkipIteration
(
self
,
_
):
return
"return;"
class
CudaSympyPrinter
(
CustomSympyPrinter
):
def
__init__
(
self
):
super
(
CudaSympyPrinter
,
self
).
__init__
()
self
.
known_functions
=
CUDA_KNOWN_FUNCTIONS
def
_print_TextureAccess
(
self
,
node
):
if
node
.
texture
.
cubic_bspline_interpolation
:
template
=
"cubicTex%iDSimple<%s>(%s, %s)"
else
:
template
=
"tex%iD<%s>(%s, %s)"
code
=
template
%
(
node
.
texture
.
field
.
spatial_dimensions
,
str
(
node
.
texture
.
field
.
dtype
),
str
(
node
.
texture
),
', '
.
join
(
self
.
_print
(
o
)
for
o
in
node
.
offsets
)
)
return
code
def
_print_Function
(
self
,
expr
):
if
isinstance
(
expr
,
fast_division
):
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
return
"__fsqrt_rn(%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"__frsqrt_rn(%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
super
().
_print_Function
(
expr
)
pystencils/backends/cuda_known_functions.txt
0 → 100644
View file @
faf330f8
__prof_trigger
printf
__syncthreads
__syncthreads_count
__syncthreads_and
__syncthreads_or
__syncwarp
__threadfence
__threadfence_block
__threadfence_system
atomicAdd
atomicSub
atomicExch
atomicMin
atomicMax
atomicInc
atomicDec
atomicAnd
atomicOr
atomicXor
atomicCAS
__all_sync
__any_sync
__ballot_sync
__active_mask
__shfl_sync
__shfl_up_sync
__shfl_down_sync
__shfl_xor_sync
__match_any_sync
__match_all_sync
__isGlobal
__isShared
__isConstant
__isLocal
tex1Dfetch
tex1D
tex2D
tex3D
rsqrtf
cbrtf
rcbrtf
hypotf
rhypotf
norm3df
rnorm3df
norm4df
rnorm4df
normf
rnormf
expf
exp2f
exp10f
expm1f
logf
log2f
log10f
log1pf
sinf
cosf
tanf
sincosf
sinpif
cospif
sincospif
asinf
acosf
atanf
atan2f
sinhf
coshf
tanhf
asinhf
acoshf
atanhf
powf
erff
erfcf
erfinvf
erfcinvf
erfcxf
normcdff
normcdfinvf
lgammaf
tgammaf
fmaf
frexpf
ldexpf
scalbnf
scalblnf
logbf
ilogbf
j0f
j1f
jnf
y0f
y1f
ynf
cyl_bessel_i0f
cyl_bessel_i1f
fmodf
remainderf
remquof
modff
fdimf
truncf
roundf
rintf
nearbyintf
ceilf
floorf
lrintf
lroundf
llrintf
llroundf
sqrt
rsqrt
cbrt
rcbrt
hypot
rhypot
norm3d
rnorm3d
norm4d
rnorm4d
norm
rnorm
exp
exp2
exp10
expm1
log
log2
log10
log1p
sin
cos
tan
sincos
sinpi
cospi
sincospi
asin
acos
atan
atan2
sinh
cosh
tanh
asinh
acosh
atanh
pow
erf
erfc
erfinv
erfcinv
erfcx
normcdf
normcdfinv
lgamma
tgamma
fma
frexp
ldexp
scalbn
scalbln
logb
ilogb
j0
j1
jn
y0
y1
yn
cyl_bessel_i0
cyl_bessel_i1
fmod
remainder
remquo
mod
fdim
trunc
round
rint
nearbyint
ceil
floor
lrint
lround
llrint
llround
__fdividef
__sinf
__cosf
__tanf
__sincosf
__logf
__log2f
__log10f
__expf
__exp10f
__powf
__fadd_rn
__fsub_rn
__fmul_rn
__fmaf_rn
__frcp_rn
__fsqrt_rn
__frsqrt_rn
__fdiv_rn
__fadd_rz
__fsub_rz
__fmul_rz
__fmaf_rz
__frcp_rz
__fsqrt_rz
__frsqrt_rz
__fdiv_rz
__fadd_ru
__fsub_ru
__fmul_ru
__fmaf_ru
__frcp_ru
__fsqrt_ru
__frsqrt_ru
__fdiv_ru
__fadd_rd
__fsub_rd
__fmul_rd
__fmaf_rd
__frcp_rd
__fsqrt_rd
__frsqrt_rd
__fdiv_rd
__fdividef
__expf
__exp10f
__logf
__log2f
__log10f
__sinf
__cosf
__sincosf
__tanf
__powf
__dadd_rn
__dsub_rn
__dmul_rn
__fma_rn
__ddiv_rn
__drcp_rn
__dsqrt_rn
__dadd_rz
__dsub_rz
__dmul_rz
__fma_rz
__ddiv_rz
__drcp_rz
__dsqrt_rz
__dadd_ru
__dsub_ru
__dmul_ru
__fma_ru
__ddiv_ru
__drcp_ru
__dsqrt_ru
__dadd_rd
__dsub_rd
__dmul_rd
__fma_rd
__ddiv_rd
__drcp_rd
__dsqrt_rd
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment