Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
bec1010f
Commit
bec1010f
authored
Sep 22, 2019
by
Stephan Seitz
Browse files
llvm: Implement LLVMPrinter._print_ThreadIndexingSymbol
parent
a4b64edf
Changes
3
Hide whitespace changes
Inline
Side-by-side
pystencils/gpucuda/indexing.py
View file @
bec1010f
...
...
@@ -24,10 +24,10 @@ class ThreadIndexingSymbol(TypedSymbol):
__xnew_cached_
=
staticmethod
(
cacheit
(
__new_stage2__
))
BLOCK_IDX
=
[
ThreadIndexingSymbol
(
"blockIdx."
+
coord
,
create_type
(
"int"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
THREAD_IDX
=
[
ThreadIndexingSymbol
(
"threadIdx."
+
coord
,
create_type
(
"int"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
BLOCK_DIM
=
[
ThreadIndexingSymbol
(
"blockDim."
+
coord
,
create_type
(
"int"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
GRID_DIM
=
[
ThreadIndexingSymbol
(
"gridDim."
+
coord
,
create_type
(
"int"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
BLOCK_IDX
=
[
ThreadIndexingSymbol
(
"blockIdx."
+
coord
,
create_type
(
"int
32
"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
THREAD_IDX
=
[
ThreadIndexingSymbol
(
"threadIdx."
+
coord
,
create_type
(
"int
32
"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
BLOCK_DIM
=
[
ThreadIndexingSymbol
(
"blockDim."
+
coord
,
create_type
(
"int
32
"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
GRID_DIM
=
[
ThreadIndexingSymbol
(
"gridDim."
+
coord
,
create_type
(
"int
32
"
))
for
coord
in
(
'x'
,
'y'
,
'z'
)]
class
AbstractIndexing
(
abc
.
ABC
):
...
...
pystencils/llvm/kernelcreation.py
View file @
bec1010f
...
...
@@ -3,7 +3,7 @@ from pystencils.transformations import insert_casts
def
create_kernel
(
assignments
,
function_name
=
"kernel"
,
type_info
=
None
,
split_groups
=
(),
iteration_slice
=
None
,
ghost_layers
=
None
):
iteration_slice
=
None
,
ghost_layers
=
None
,
target
=
'cpu'
):
"""
Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
...
...
@@ -25,9 +25,20 @@ def create_kernel(assignments, function_name="kernel", type_info=None, split_gro
:return: :class:`pystencils.ast.KernelFunction` node
"""
from
pystencils.cpu
import
create_kernel
code
=
create_kernel
(
assignments
,
function_name
,
type_info
,
split_groups
,
iteration_slice
,
ghost_layers
)
if
target
==
'cpu'
:
from
pystencils.cpu
import
create_kernel
code
=
create_kernel
(
assignments
,
function_name
,
type_info
,
split_groups
,
iteration_slice
,
ghost_layers
)
elif
target
==
'gpu'
:
from
pystencils.gpucuda.kernelcreation
import
create_cuda_kernel
code
=
create_cuda_kernel
(
assignments
,
function_name
,
type_info
,
iteration_slice
=
iteration_slice
,
ghost_layers
=
ghost_layers
)
else
:
NotImplementedError
()
code
.
body
=
insert_casts
(
code
.
body
)
code
.
_compile_function
=
make_python_function
code
.
_backend
=
'llvm'
return
code
pystencils/llvm/llvm.py
View file @
bec1010f
import
functools
import
llvmlite.ir
as
ir
import
llvmlite.llvmpy.core
as
lc
import
sympy
as
sp
from
sympy
import
Indexed
,
S
from
sympy.printing.printer
import
Printer
...
...
@@ -12,10 +13,18 @@ from pystencils.data_types import (
from
pystencils.llvm.control_flow
import
Loop
# From Numba
def
_call_sreg
(
builder
,
name
):
module
=
builder
.
module
fnty
=
lc
.
Type
.
function
(
lc
.
Type
.
int
(),
())
fn
=
module
.
get_or_insert_function
(
fnty
,
name
=
name
)
return
builder
.
call
(
fn
,
())
def
generate_llvm
(
ast_node
,
module
=
None
,
builder
=
None
):
"""Prints the ast as llvm code."""
if
module
is
None
:
module
=
ir
.
Module
()
module
=
lc
.
Module
()
if
builder
is
None
:
builder
=
ir
.
IRBuilder
()
printer
=
LLVMPrinter
(
module
,
builder
)
...
...
@@ -330,3 +339,19 @@ class LLVMPrinter(Printer):
mro
=
"None"
raise
TypeError
(
"Unsupported type for LLVM JIT conversion: Expression:
\"
%s
\"
, Type:
\"
%s
\"
, MRO:%s"
%
(
expr
,
type
(
expr
),
mro
))
# from: https://llvm.org/docs/NVPTXUsage.html#nvptx-intrinsics
INDEXING_FUNCTION_MAPPING
=
{
'blockIdx'
:
'llvm.nvvm.read.ptx.sreg.ctaid'
,
'threadIdx'
:
'llvm.nvvm.read.ptx.sreg.tid'
,
'blockDim'
:
'llvm.nvvm.read.ptx.sreg.ntid'
,
'gridDim'
:
'llvm.nvvm.read.ptx.sreg.nctaid'
}
def
_print_ThreadIndexingSymbol
(
self
,
node
):
symbol_name
:
str
=
node
.
name
function_name
,
dimension
=
tuple
(
symbol_name
.
split
(
"."
))
function_name
=
self
.
INDEXING_FUNCTION_MAPPING
[
function_name
]
name
=
f
"
{
function_name
}
.
{
dimension
}
"
return
self
.
builder
.
zext
(
_call_sreg
(
self
.
builder
,
name
),
self
.
integer
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment