Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Christoph Alt
pystencils
Commits
fce1a417
Commit
fce1a417
authored
May 21, 2021
by
Jan Hönig
Browse files
Merge branch 'sve' into 'master'
Sizeless vectorization See merge request
pycodegen/pystencils!234
parents
cc645538
cdf73d8f
Changes
15
Hide whitespace changes
Inline
Side-by-side
.gitlab-ci.yml
View file @
fce1a417
...
...
@@ -178,7 +178,7 @@ arm64v9:
extends
:
.multiarch_template
image
:
i10git.cs.fau.de:5005/pycodegen/pycodegen/arm64
variables
:
PYSTENCILS_SIMD
:
"
sve256,sve512"
PYSTENCILS_SIMD
:
"
sve256,sve512
,sve
"
ASAN_OPTIONS
:
detect_leaks=0
LD_PRELOAD
:
/usr/lib/aarch64-linux-gnu/libasan.so.6
before_script
:
...
...
@@ -186,6 +186,20 @@ arm64v9:
-
sed -i s/march=native/march=armv8-a+sve/g ~/.config/pystencils/config.json
-
sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
riscv64
:
# The RISC-V vector extension is still experimental and needs special compiler flags.
# Once they are officially released, this job should be cleaned up to match the others.
extends
:
.multiarch_template
image
:
i10git.cs.fau.de:5005/pycodegen/pycodegen/riscv64
variables
:
PYSTENCILS_SIMD
:
"
rvv"
QEMU_CPU
:
"
rv64,x-v=true"
before_script
:
-
*multiarch_before_script
-
sed -i 's/march=native/march=rv64imfdv0p10 -menable-experimental-extensions/g' ~/.config/pystencils/config.json
-
sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
-
sed -i 's/fopenmp/fopenmp=libgomp -I\/usr\/include\/riscv64-linux-gnu/g' ~/.config/pystencils/config.json
minimal-conda
:
stage
:
test
except
:
...
...
pystencils/alignedarray.py
View file @
fce1a417
...
...
@@ -28,13 +28,19 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
elif
byte_alignment
==
'cacheline'
:
cacheline_sizes
=
[
get_cacheline_size
(
is_name
)
for
is_name
in
instruction_sets
]
if
all
([
s
is
None
for
s
in
cacheline_sizes
]):
byte_alignment
=
max
([
get_vector_instruction_set
(
type_name
,
is_name
)[
'width'
]
*
np
.
dtype
(
dtype
).
itemsize
for
is_name
in
instruction_sets
])
widths
=
[
get_vector_instruction_set
(
type_name
,
is_name
)[
'width'
]
*
np
.
dtype
(
dtype
).
itemsize
for
is_name
in
instruction_sets
if
type
(
get_vector_instruction_set
(
type_name
,
is_name
)[
'width'
])
is
int
]
byte_alignment
=
64
if
all
([
s
is
None
for
s
in
widths
])
else
max
(
widths
)
else
:
byte_alignment
=
max
([
s
for
s
in
cacheline_sizes
if
s
is
not
None
])
elif
not
any
([
type
(
get_vector_instruction_set
(
type_name
,
is_name
)[
'width'
])
is
int
for
is_name
in
instruction_sets
]):
byte_alignment
=
64
else
:
byte_alignment
=
max
([
get_vector_instruction_set
(
type_name
,
is_name
)[
'width'
]
*
np
.
dtype
(
dtype
).
itemsize
for
is_name
in
instruction_sets
])
for
is_name
in
instruction_sets
if
type
(
get_vector_instruction_set
(
type_name
,
is_name
)[
'width'
])
is
int
])
if
(
not
align_inner_coordinate
)
or
(
not
hasattr
(
shape
,
'__len__'
)):
size
=
np
.
prod
(
shape
)
d
=
np
.
dtype
(
dtype
)
...
...
pystencils/backends/arm_instruction_sets.py
View file @
fce1a417
...
...
@@ -19,9 +19,8 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if
instruction_set
!=
'neon'
and
not
instruction_set
.
startswith
(
'sve'
):
raise
NotImplementedError
(
instruction_set
)
if
instruction_set
==
'sve'
:
raise
NotImplementedError
(
"sizeless SVE is not implemented"
)
if
instruction_set
.
startswith
(
'sve'
):
cmp
=
'cmp'
elif
instruction_set
.
startswith
(
'sve'
):
cmp
=
'cmp'
bitwidth
=
int
(
instruction_set
[
3
:])
elif
instruction_set
==
'neon'
:
...
...
@@ -53,8 +52,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
'float'
:
32
,
'int'
:
32
}
width
=
bitwidth
//
bits
[
data_type
]
intwidth
=
bitwidth
//
bits
[
'int'
]
result
=
dict
()
if
instruction_set
==
'sve'
:
width
=
'svcntd()'
if
data_type
==
'double'
else
'svcntw()'
intwidth
=
'svcntw()'
result
[
'bytes'
]
=
'svcntb()'
else
:
width
=
bitwidth
//
bits
[
data_type
]
intwidth
=
bitwidth
//
bits
[
'int'
]
result
[
'bytes'
]
=
bitwidth
//
8
if
instruction_set
.
startswith
(
'sve'
):
prefix
=
'sv'
suffix
=
f
'_f
{
bits
[
data_type
]
}
'
...
...
@@ -62,11 +69,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
prefix
=
'v'
suffix
=
f
'q_f
{
bits
[
data_type
]
}
'
result
=
dict
()
result
[
'bytes'
]
=
bitwidth
//
8
predicate
=
f
'
{
prefix
}
whilelt_b
{
bits
[
data_type
]
}
(0,
{
width
}
)'
int_predicate
=
f
'
{
prefix
}
whilelt_b
{
bits
[
"int"
]
}
(0,
{
intwidth
}
)'
if
instruction_set
==
'sve'
:
predicate
=
f
'
{
prefix
}
whilelt_b
{
bits
[
data_type
]
}
_u64({{loop_counter}}, {{loop_stop}})'
int_predicate
=
f
'
{
prefix
}
whilelt_b
{
bits
[
"int"
]
}
_u64({{loop_counter}}, {{loop_stop}})'
else
:
predicate
=
f
'
{
prefix
}
whilelt_b
{
bits
[
data_type
]
}
(0,
{
width
}
)'
int_predicate
=
f
'
{
prefix
}
whilelt_b
{
bits
[
"int"
]
}
(0,
{
intwidth
}
)'
for
intrinsic_id
,
function_shortcut
in
base_names
.
items
():
function_shortcut
=
function_shortcut
.
strip
()
...
...
@@ -80,8 +88,13 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result
[
intrinsic_id
]
=
prefix
+
name
+
suffix
+
undef
+
arg_string
result
[
'width'
]
=
width
result
[
'intwidth'
]
=
intwidth
if
instruction_set
==
'sve'
:
from
pystencils.backends.cbackend
import
CFunction
result
[
'width'
]
=
CFunction
(
width
,
"int"
)
result
[
'intwidth'
]
=
CFunction
(
intwidth
,
"int"
)
else
:
result
[
'width'
]
=
width
result
[
'intwidth'
]
=
intwidth
if
instruction_set
.
startswith
(
'sve'
):
result
[
'makeVecConst'
]
=
f
'svdup_f
{
bits
[
data_type
]
}
'
+
'({0})'
...
...
@@ -89,17 +102,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result
[
'makeVecIndex'
]
=
f
'svindex_s
{
bits
[
"int"
]
}
'
+
'({0}, {1})'
vindex
=
f
'svindex_u
{
bits
[
data_type
]
}
(0, {{0}})'
result
[
's
catter
'
]
=
f
'svst1_scatter_u
{
bits
[
data_type
]
}
index_f
{
bits
[
data_type
]
}
(
{
predicate
}
, {{0}}, '
+
\
vindex
.
format
(
"{2}"
)
+
', {1})'
result
[
'
gather
'
]
=
f
'svld1_gather_u
{
bits
[
data_type
]
}
index_f
{
bits
[
data_type
]
}
(
{
predicate
}
, {{0}}, '
+
\
vindex
.
format
(
"{1}"
)
+
')'
result
[
's
toreS
'
]
=
f
'svst1_scatter_u
{
bits
[
data_type
]
}
index_f
{
bits
[
data_type
]
}
(
{
predicate
}
, {{0}}, '
+
\
vindex
.
format
(
"{2}"
)
+
', {1})'
result
[
'
loadS
'
]
=
f
'svld1_gather_u
{
bits
[
data_type
]
}
index_f
{
bits
[
data_type
]
}
(
{
predicate
}
, {{0}}, '
+
\
vindex
.
format
(
"{1}"
)
+
')'
result
[
'+int'
]
=
f
"svadd_s
{
bits
[
'int'
]
}
_x(
{
int_predicate
}
, "
+
"{0}, {1})"
result
[
'float'
]
=
'svfloat
32_s
t'
result
[
'double'
]
=
'svfloat
64_s
t'
result
[
'int'
]
=
f
'svint
{
bits
[
"int"
]
}
_
s
t'
result
[
'bool'
]
=
'svbool_
s
t'
result
[
'float'
]
=
f
'svfloat
{
bits
[
"float"
]
}
_
{
"s"
if
instruction_set
!=
"sve"
else
""
}
t'
result
[
'double'
]
=
f
'svfloat
{
bits
[
"double"
]
}
_
{
"s"
if
instruction_set
!=
"sve"
else
""
}
t'
result
[
'int'
]
=
f
'svint
{
bits
[
"int"
]
}
_
{
"s"
if
instruction_set
!=
"sve"
else
""
}
t'
result
[
'bool'
]
=
f
'svbool_
{
"s"
if
instruction_set
!=
"sve"
else
""
}
t'
result
[
'headers'
]
=
[
'<arm_sve.h>'
,
'"arm_neon_helpers.h"'
]
...
...
@@ -111,9 +124,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result
[
'maskStoreU'
]
=
result
[
'storeU'
].
replace
(
predicate
,
'{2}'
)
result
[
'maskStoreA'
]
=
result
[
'storeA'
].
replace
(
predicate
,
'{2}'
)
result
[
'maskS
catter
'
]
=
result
[
's
catter
'
].
replace
(
predicate
,
'{3}'
)
result
[
'maskS
toreS
'
]
=
result
[
's
toreS
'
].
replace
(
predicate
,
'{3}'
)
result
[
'compile_flags'
]
=
[
f
'-msve-vector-bits=
{
bitwidth
}
'
]
if
instruction_set
!=
'sve'
:
result
[
'compile_flags'
]
=
[
f
'-msve-vector-bits=
{
bitwidth
}
'
]
else
:
result
[
'makeVecConst'
]
=
f
'vdupq_n_f
{
bits
[
data_type
]
}
'
+
'({0})'
result
[
'makeVec'
]
=
f
'makeVec_f
{
bits
[
data_type
]
}
'
+
'('
+
", "
.
join
([
'{'
+
str
(
i
)
+
'}'
for
i
in
...
...
@@ -137,7 +151,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result
[
'any'
]
=
f
'vaddlvq_u8(vreinterpretq_u8_u
{
bits
[
data_type
]
}
({{0}})) > 0'
result
[
'all'
]
=
f
'vaddlvq_u8(vreinterpretq_u8_u
{
bits
[
data_type
]
}
({{0}})) == 16*0xff'
if
bitwidth
&
(
bitwidth
-
1
)
==
0
:
if
instruction_set
==
'sve'
or
bitwidth
&
(
bitwidth
-
1
)
==
0
:
# only power-of-2 vector sizes will evenly divide a cacheline
result
[
'cachelineSize'
]
=
'cachelineSize()'
result
[
'cachelineZero'
]
=
'cachelineZero((void*) {0})'
...
...
pystencils/backends/cbackend.py
View file @
fce1a417
...
...
@@ -6,6 +6,7 @@ from typing import Set
import
numpy
as
np
import
sympy
as
sp
from
sympy.core
import
S
from
sympy.core.cache
import
cacheit
from
sympy.logic.boolalg
import
BooleanFalse
,
BooleanTrue
from
pystencils.astnodes
import
KernelFunction
,
LoopOverCoordinate
,
Node
...
...
@@ -165,6 +166,23 @@ class PrintNode(CustomCodeNode):
self
.
headers
.
append
(
"<iostream>"
)
class
CFunction
(
TypedSymbol
):
def
__new__
(
cls
,
function
,
dtype
):
return
CFunction
.
__xnew_cached_
(
cls
,
function
,
dtype
)
def
__new_stage2__
(
cls
,
function
,
dtype
):
return
super
(
CFunction
,
cls
).
__xnew__
(
cls
,
function
,
dtype
)
__xnew__
=
staticmethod
(
__new_stage2__
)
__xnew_cached_
=
staticmethod
(
cacheit
(
__new_stage2__
))
def
__getnewargs__
(
self
):
return
self
.
name
,
self
.
dtype
def
__getnewargs_ex__
(
self
):
return
(
self
.
name
,
self
.
dtype
),
{}
# ------------------------------------------- Printer ------------------------------------------------------------------
...
...
@@ -184,6 +202,8 @@ class CBackend:
self
.
_indent
=
" "
self
.
_dialect
=
dialect
self
.
_signatureOnly
=
signature_only
self
.
_kwargs
=
{}
self
.
sympy_printer
.
_kwargs
=
self
.
_kwargs
def
__call__
(
self
,
node
):
prev_is
=
VectorType
.
instruction_set
...
...
@@ -205,7 +225,8 @@ class CBackend:
return
str
(
node
)
def
_print_KernelFunction
(
self
,
node
):
function_arguments
=
[
f
"
{
self
.
_print
(
s
.
symbol
.
dtype
)
}
{
s
.
symbol
.
name
}
"
for
s
in
node
.
get_parameters
()]
function_arguments
=
[
f
"
{
self
.
_print
(
s
.
symbol
.
dtype
)
}
{
s
.
symbol
.
name
}
"
for
s
in
node
.
get_parameters
()
if
not
type
(
s
.
symbol
)
is
CFunction
]
launch_bounds
=
""
if
self
.
_dialect
==
'cuda'
:
max_threads
=
node
.
indexing
.
max_threads_per_block
()
...
...
@@ -232,6 +253,8 @@ class CBackend:
condition
=
f
"
{
counter_symbol
}
<
{
self
.
sympy_printer
.
doprint
(
node
.
stop
)
}
"
update
=
f
"
{
counter_symbol
}
+=
{
self
.
sympy_printer
.
doprint
(
node
.
step
)
}
"
loop_str
=
f
"for (
{
start
}
;
{
condition
}
;
{
update
}
)"
self
.
_kwargs
[
'loop_counter'
]
=
counter_symbol
self
.
_kwargs
[
'loop_stop'
]
=
node
.
stop
prefix
=
"
\n
"
.
join
(
node
.
prefix_lines
)
if
prefix
:
...
...
@@ -265,7 +288,8 @@ class CBackend:
if
instr
not
in
self
.
_vector_instruction_set
:
self
.
_vector_instruction_set
[
instr
]
=
self
.
_vector_instruction_set
[
'store'
+
instr
[
-
1
]].
format
(
'{0}'
,
self
.
_vector_instruction_set
[
'blendv'
].
format
(
self
.
_vector_instruction_set
[
'load'
+
instr
[
-
1
]].
format
(
'{0}'
),
'{1}'
,
'{2}'
))
self
.
_vector_instruction_set
[
'load'
+
instr
[
-
1
]].
format
(
'{0}'
,
**
self
.
_kwargs
),
'{1}'
,
'{2}'
,
**
self
.
_kwargs
),
**
self
.
_kwargs
)
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
if
data_type
.
base_type
.
base_name
==
'double'
:
if
self
.
_vector_instruction_set
[
'double'
]
==
'__m256d'
:
...
...
@@ -287,9 +311,9 @@ class CBackend:
ptr
=
"&"
+
self
.
sympy_printer
.
doprint
(
node
.
lhs
.
args
[
0
])
if
stride
!=
1
:
instr
=
'maskS
catter
'
if
mask
!=
True
else
's
catter
'
# NOQA
instr
=
'maskS
toreS
'
if
mask
!=
True
else
's
toreS
'
# NOQA
return
self
.
_vector_instruction_set
[
instr
].
format
(
ptr
,
self
.
sympy_printer
.
doprint
(
rhs
),
stride
,
printed_mask
)
+
';'
stride
,
printed_mask
,
**
self
.
_kwargs
)
+
';'
pre_code
=
''
if
nontemporal
and
'cachelineZero'
in
self
.
_vector_instruction_set
:
...
...
@@ -301,22 +325,22 @@ class CBackend:
element_size
=
8
if
data_type
.
base_type
.
base_name
==
'double'
else
4
size_cond
=
f
"(
{
offset
}
+
{
CachelineSize
.
symbol
/
element_size
}
) <
{
size
}
"
pre_code
=
f
"if (
{
first_cond
}
&&
{
size_cond
}
) "
+
"{
\n\t
"
+
\
self
.
_vector_instruction_set
[
'cachelineZero'
].
format
(
ptr
)
+
';
\n
}
\n
'
self
.
_vector_instruction_set
[
'cachelineZero'
].
format
(
ptr
,
**
self
.
_kwargs
)
+
';
\n
}
\n
'
code
=
self
.
_vector_instruction_set
[
instr
].
format
(
ptr
,
self
.
sympy_printer
.
doprint
(
rhs
),
printed_mask
)
+
';'
printed_mask
,
**
self
.
_kwargs
)
+
';'
flushcond
=
f
"((uintptr_t)
{
ptr
}
&
{
CachelineSize
.
mask_symbol
}
) ==
{
CachelineSize
.
last_symbol
}
"
if
nontemporal
and
'flushCacheline'
in
self
.
_vector_instruction_set
:
code2
=
self
.
_vector_instruction_set
[
'flushCacheline'
].
format
(
ptr
,
self
.
sympy_printer
.
doprint
(
rhs
))
+
';'
ptr
,
self
.
sympy_printer
.
doprint
(
rhs
)
,
**
self
.
_kwargs
)
+
';'
code
=
f
"
{
code
}
\n
if (
{
flushcond
}
) {{
\n\t
{
code2
}
\n
}}"
elif
nontemporal
and
'storeAAndFlushCacheline'
in
self
.
_vector_instruction_set
:
tmpvar
=
'_tmp_'
+
hashlib
.
sha1
(
self
.
sympy_printer
.
doprint
(
rhs
).
encode
(
'ascii'
)).
hexdigest
()[:
8
]
code
=
'const '
+
self
.
_print
(
node
.
lhs
.
dtype
).
replace
(
' const'
,
''
)
+
' '
+
tmpvar
+
' = '
\
+
self
.
sympy_printer
.
doprint
(
rhs
)
+
';'
code1
=
self
.
_vector_instruction_set
[
instr
].
format
(
ptr
,
tmpvar
,
printed_mask
)
+
';'
code2
=
self
.
_vector_instruction_set
[
'storeAAndFlushCacheline'
].
format
(
ptr
,
tmpvar
,
printed_mask
)
\
+
';'
code1
=
self
.
_vector_instruction_set
[
instr
].
format
(
ptr
,
tmpvar
,
printed_mask
,
**
self
.
_kwargs
)
+
';'
code2
=
self
.
_vector_instruction_set
[
'storeAAndFlushCacheline'
].
format
(
ptr
,
tmpvar
,
printed_mask
,
**
self
.
_kwargs
)
+
';'
code
+=
f
"
\n
if (
{
flushcond
}
) {{
\n\t
{
code2
}
\n
}} else {{
\n\t
{
code1
}
\n
}}"
return
pre_code
+
code
else
:
...
...
@@ -617,16 +641,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def
_print_Abs
(
self
,
expr
):
if
'abs'
in
self
.
instruction_set
and
isinstance
(
expr
.
args
[
0
],
vector_memory_access
):
return
self
.
instruction_set
[
'abs'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
return
self
.
instruction_set
[
'abs'
].
format
(
self
.
_print
(
expr
.
args
[
0
])
,
**
self
.
_kwargs
)
return
super
().
_print_Abs
(
expr
)
def
_print_Function
(
self
,
expr
):
if
isinstance
(
expr
,
vector_memory_access
):
arg
,
data_type
,
aligned
,
_
,
mask
,
stride
=
expr
.
args
if
stride
!=
1
:
return
self
.
instruction_set
[
'
gather
'
].
format
(
"& "
+
self
.
_print
(
arg
),
stride
)
return
self
.
instruction_set
[
'
loadS
'
].
format
(
"& "
+
self
.
_print
(
arg
),
stride
,
**
self
.
_kwargs
)
instruction
=
self
.
instruction_set
[
'loadA'
]
if
aligned
else
self
.
instruction_set
[
'loadU'
]
return
instruction
.
format
(
"& "
+
self
.
_print
(
arg
))
return
instruction
.
format
(
"& "
+
self
.
_print
(
arg
)
,
**
self
.
_kwargs
)
elif
isinstance
(
expr
,
cast_func
):
arg
,
data_type
=
expr
.
args
if
type
(
data_type
)
is
VectorType
:
...
...
@@ -640,19 +664,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
instruction
==
'makeVecInt'
and
'makeVecIndex'
in
self
.
instruction_set
:
increments
=
np
.
array
(
arg
)[
1
:]
-
np
.
array
(
arg
)[:
-
1
]
if
len
(
set
(
increments
))
==
1
:
return
self
.
instruction_set
[
'makeVecIndex'
].
format
(
printed_args
[
0
],
increments
[
0
])
return
self
.
instruction_set
[
instruction
].
format
(
*
printed_args
)
return
self
.
instruction_set
[
'makeVecIndex'
].
format
(
printed_args
[
0
],
increments
[
0
],
**
self
.
_kwargs
)
return
self
.
instruction_set
[
instruction
].
format
(
*
printed_args
,
**
self
.
_kwargs
)
else
:
is_boolean
=
get_type_of_expression
(
arg
)
==
create_type
(
"bool"
)
is_integer
=
get_type_of_expression
(
arg
)
==
create_type
(
"int"
)
or
\
(
isinstance
(
arg
,
TypedSymbol
)
and
arg
.
dtype
.
is_int
())
instruction
=
'makeVecConstBool'
if
is_boolean
else
\
'makeVecConstInt'
if
is_integer
else
'makeVecConst'
return
self
.
instruction_set
[
instruction
].
format
(
self
.
_print
(
arg
))
return
self
.
instruction_set
[
instruction
].
format
(
self
.
_print
(
arg
)
,
**
self
.
_kwargs
)
elif
expr
.
func
==
fast_division
:
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]),
**
self
.
_kwargs
)
return
result
elif
expr
.
func
==
fast_sqrt
:
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
...
...
@@ -660,7 +686,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
if
'rsqrt'
in
self
.
instruction_set
:
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
])
,
**
self
.
_kwargs
)
else
:
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
...
...
@@ -672,8 +698,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
isinstance
(
expr
.
args
[
0
],
sp
.
Rel
):
op
=
expr
.
args
[
0
].
rel_op
if
(
instr
,
op
)
in
self
.
instruction_set
:
return
self
.
instruction_set
[(
instr
,
op
)].
format
(
*
[
self
.
_print
(
a
)
for
a
in
expr
.
args
[
0
].
args
])
return
self
.
instruction_set
[
instr
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
return
self
.
instruction_set
[(
instr
,
op
)].
format
(
*
[
self
.
_print
(
a
)
for
a
in
expr
.
args
[
0
].
args
],
**
self
.
_kwargs
)
return
self
.
instruction_set
[
instr
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
**
self
.
_kwargs
)
return
super
(
VectorizedCustomSympyPrinter
,
self
).
_print_Function
(
expr
)
...
...
@@ -686,7 +713,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert
len
(
arg_strings
)
>
0
result
=
arg_strings
[
0
]
for
item
in
arg_strings
[
1
:]:
result
=
self
.
instruction_set
[
'&'
].
format
(
result
,
item
)
result
=
self
.
instruction_set
[
'&'
].
format
(
result
,
item
,
**
self
.
_kwargs
)
return
result
def
_print_Or
(
self
,
expr
):
...
...
@@ -698,7 +725,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert
len
(
arg_strings
)
>
0
result
=
arg_strings
[
0
]
for
item
in
arg_strings
[
1
:]:
result
=
self
.
instruction_set
[
'|'
].
format
(
result
,
item
)
result
=
self
.
instruction_set
[
'|'
].
format
(
result
,
item
,
**
self
.
_kwargs
)
return
result
def
_print_Add
(
self
,
expr
,
order
=
None
):
...
...
@@ -739,7 +766,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
processed
=
summands
[
0
].
term
for
summand
in
summands
[
1
:]:
func
=
self
.
instruction_set
[
'-'
+
suffix
]
if
summand
.
sign
==
-
1
else
self
.
instruction_set
[
'+'
+
suffix
]
processed
=
func
.
format
(
processed
,
summand
.
term
)
processed
=
func
.
format
(
processed
,
summand
.
term
,
**
self
.
_kwargs
)
return
processed
def
_print_Pow
(
self
,
expr
):
...
...
@@ -747,21 +774,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
result
:
return
result
one
=
self
.
instruction_set
[
'makeVecConst'
].
format
(
1.0
)
one
=
self
.
instruction_set
[
'makeVecConst'
].
format
(
1.0
,
**
self
.
_kwargs
)
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
elif
expr
.
exp
==
-
1
:
one
=
self
.
instruction_set
[
'makeVecConst'
].
format
(
1.0
)
return
self
.
instruction_set
[
'/'
].
format
(
one
,
self
.
_print
(
expr
.
base
))
one
=
self
.
instruction_set
[
'makeVecConst'
].
format
(
1.0
,
**
self
.
_kwargs
)
return
self
.
instruction_set
[
'/'
].
format
(
one
,
self
.
_print
(
expr
.
base
)
,
**
self
.
_kwargs
)
elif
expr
.
exp
==
0.5
:
return
self
.
instruction_set
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
))
return
self
.
instruction_set
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
)
,
**
self
.
_kwargs
)
elif
expr
.
exp
==
-
0.5
:
root
=
self
.
instruction_set
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
))
return
self
.
instruction_set
[
'/'
].
format
(
one
,
root
)
root
=
self
.
instruction_set
[
'sqrt'
].
format
(
self
.
_print
(
expr
.
base
)
,
**
self
.
_kwargs
)
return
self
.
instruction_set
[
'/'
].
format
(
one
,
root
,
**
self
.
_kwargs
)
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
self
.
instruction_set
[
'/'
].
format
(
one
,
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
(
-
expr
.
exp
),
evaluate
=
False
)))
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
(
-
expr
.
exp
),
evaluate
=
False
)),
**
self
.
_kwargs
)
else
:
raise
ValueError
(
"Generic exponential not supported: "
+
str
(
expr
))
...
...
@@ -800,19 +828,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
a_str
[
0
]
for
item
in
a_str
[
1
:]:
result
=
self
.
instruction_set
[
'*'
].
format
(
result
,
item
)
result
=
self
.
instruction_set
[
'*'
].
format
(
result
,
item
,
**
self
.
_kwargs
)
if
len
(
b
)
>
0
:
denominator_str
=
b_str
[
0
]
for
item
in
b_str
[
1
:]:
denominator_str
=
self
.
instruction_set
[
'*'
].
format
(
denominator_str
,
item
)
result
=
self
.
instruction_set
[
'/'
].
format
(
result
,
denominator_str
)
denominator_str
=
self
.
instruction_set
[
'*'
].
format
(
denominator_str
,
item
,
**
self
.
_kwargs
)
result
=
self
.
instruction_set
[
'/'
].
format
(
result
,
denominator_str
,
**
self
.
_kwargs
)
if
inside_add
:
return
sign
,
result
else
:
if
sign
<
0
:
return
self
.
instruction_set
[
'*'
].
format
(
self
.
_print
(
S
.
NegativeOne
),
result
)
return
self
.
instruction_set
[
'*'
].
format
(
self
.
_print
(
S
.
NegativeOne
),
result
,
**
self
.
_kwargs
)
else
:
return
result
...
...
@@ -820,13 +848,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
_scalarFallback
(
'_print_Relational'
,
expr
)
if
result
:
return
result
return
self
.
instruction_set
[
expr
.
rel_op
].
format
(
self
.
_print
(
expr
.
lhs
),
self
.
_print
(
expr
.
rhs
))
return
self
.
instruction_set
[
expr
.
rel_op
].
format
(
self
.
_print
(
expr
.
lhs
),
self
.
_print
(
expr
.
rhs
)
,
**
self
.
_kwargs
)
def
_print_Equality
(
self
,
expr
):
result
=
self
.
_scalarFallback
(
'_print_Equality'
,
expr
)
if
result
:
return
result
return
self
.
instruction_set
[
'=='
].
format
(
self
.
_print
(
expr
.
lhs
),
self
.
_print
(
expr
.
rhs
))
return
self
.
instruction_set
[
'=='
].
format
(
self
.
_print
(
expr
.
lhs
),
self
.
_print
(
expr
.
rhs
)
,
**
self
.
_kwargs
)
def
_print_Piecewise
(
self
,
expr
):
result
=
self
.
_scalarFallback
(
'_print_Piecewise'
,
expr
)
...
...
@@ -847,10 +875,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
isinstance
(
condition
,
cast_func
)
and
get_type_of_expression
(
condition
.
args
[
0
])
==
create_type
(
"bool"
):
if
not
KERNCRAFT_NO_TERNARY_MODE
:
result
=
"(({}) ? ({}) : ({}))"
.
format
(
self
.
_print
(
condition
.
args
[
0
]),
self
.
_print
(
true_expr
),
result
)
result
,
**
self
.
_kwargs
)
else
:
print
(
"Warning - skipping ternary op"
)
else
:
# noinspection SpellCheckingInspection
result
=
self
.
instruction_set
[
'blendv'
].
format
(
result
,
self
.
_print
(
true_expr
),
self
.
_print
(
condition
))
result
=
self
.
instruction_set
[
'blendv'
].
format
(
result
,
self
.
_print
(
true_expr
),
self
.
_print
(
condition
),
**
self
.
_kwargs
)
return
result
pystencils/backends/riscv_instruction_sets.py
0 → 100644
View file @
fce1a417
def
get_argument_string
(
function_shortcut
,
last
=
''
):
args
=
function_shortcut
[
function_shortcut
.
index
(
'['
)
+
1
:
-
1
]
arg_string
=
"("
for
arg
in
args
.
split
(
","
):
arg
=
arg
.
strip
()
if
not
arg
:
continue
if
arg
in
(
'0'
,
'1'
,
'2'
,
'3'
,
'4'
,
'5'
):
arg_string
+=
"{"
+
arg
+
"},"
else
:
arg_string
+=
arg
+
","
if
last
:
arg_string
+=
last
+
','
arg_string
=
arg_string
[:
-
1
]
+
")"
return
arg_string
def
get_vector_instruction_set_riscv
(
data_type
=
'double'
,
instruction_set
=
'rvv'
):
assert
instruction_set
==
'rvv'
bits
=
{
'double'
:
64
,
'float'
:
32
,
'int'
:
32
}
base_names
=
{
'+'
:
'fadd_vv[0, 1]'
,
'-'
:
'fsub_vv[0, 1]'
,
'*'
:
'fmul_vv[0, 1]'
,
'/'
:
'fdiv_vv[0, 1]'
,
'sqrt'
:
'fsqrt_v[0]'
,
'loadU'
:
f
'le
{
bits
[
data_type
]
}
_v[0]'
,
'loadA'
:
f
'le
{
bits
[
data_type
]
}
_v[0]'
,
'storeU'
:
f
'se
{
bits
[
data_type
]
}
_v[0, 1]'
,
'storeA'
:
f
'se
{
bits
[
data_type
]
}
_v[0, 1]'
,
'maskStoreU'
:
f
'se
{
bits
[
data_type
]
}
_v[2, 0, 1]'
,
'maskStoreA'
:
f
'se
{
bits
[
data_type
]
}
_v[2, 0, 1]'
,
'loadS'
:
f
'lse
{
bits
[
data_type
]
}
_v[0, 1]'
,
'storeS'
:
f
'sse
{
bits
[
data_type
]
}
_v[0, 2, 1]'
,
'maskStoreS'
:
f
'sse
{
bits
[
data_type
]
}
_v[2, 0, 3, 1]'
,
'abs'
:
'fabs_v[0]'
,
'=='
:
'mfeq_vv[0, 1]'
,
'!='
:
'mfne_vv[0, 1]'
,
'<='
:
'mfle_vv[0, 1]'
,
'<'
:
'mflt_vv[0, 1]'
,
'>='
:
'mfge_vv[0, 1]'
,
'>'
:
'mfgt_vv[0, 1]'
,
'&'
:
'mand_mm[0, 1]'
,
'|'
:
'mor_mm[0, 1]'
,
'blendv'
:
'merge_vvm[2, 0, 1]'
,
'any'
:
'popc_m[0]'
,
'all'
:
'popc_m[0]'
,
}
result
=
dict
()
width
=
f
'vsetvlmax_e
{
bits
[
data_type
]
}
m1()'
intwidth
=
'vsetvlmax_e{bits["int"]}m1()'
result
[
'bytes'
]
=
'vsetvlmax_e8m1()'
prefix
=
'v'
suffix
=
f
'_f
{
bits
[
data_type
]
}
m1'
vl
=
'{loop_stop} - {loop_counter}'
int_vl
=
f
'(
{
vl
}
)*
{
bits
[
data_type
]
//
bits
[
"int"
]
}
'
for
intrinsic_id
,
function_shortcut
in
base_names
.
items
():
function_shortcut
=
function_shortcut
.
strip
()
name
=
function_shortcut
[:
function_shortcut
.
index
(
'['
)]
if
name
.
startswith
(
'mf'
):
suffix2
=
suffix
+
f
'_b
{
bits
[
data_type
]
}
'
elif
name
.
endswith
(
'_mm'
)
or
name
.
endswith
(
'_m'
):
suffix2
=
f
'_b
{
bits
[
data_type
]
}
'
elif
intrinsic_id
.
startswith
(
'mask'
):
suffix2
=
suffix
+
'_m'
else
:
suffix2
=
suffix
arg_string
=
get_argument_string
(
function_shortcut
,
last
=
vl
)
result
[
intrinsic_id
]
=
prefix
+
name
+
suffix2
+
arg_string
from
pystencils.backends.cbackend
import
CFunction
result
[
'width'
]
=
CFunction
(
width
,
"int"
)
result
[
'intwidth'
]
=
CFunction
(
intwidth
,
"int"
)
result
[
'makeVecConst'
]
=
f
'vfmv_v_f_f
{
bits
[
data_type
]
}
m1({{0}},
{
vl
}
)'
result
[
'makeVecConstInt'
]
=
f
'vmv_v_x_i
{
bits
[
"int"
]
}
m1({{0}},
{
int_vl
}
)'
result
[
'makeVecIndex'
]
=
f
'vmacc_vx_i
{
bits
[
"int"
]
}
m1(
{
result
[
"makeVecConstInt"
]
}
, {{1}}, '
+
\
f
'vid_v_i
{
bits
[
"int"
]
}
m1(
{
int_vl
}
),
{
int_vl
}
)'
result
[
'storeS'
]
=
result
[
'storeS'
].
replace
(
'{2}'
,
f
'{{2}}*
{
bits
[
data_type
]
//
8
}
'
)
result
[
'loadS'
]
=
result
[
'loadS'
].
replace
(
'{1}'
,
f
'{{1}}*
{
bits
[
data_type
]
//
8
}
'
)
result
[
'maskStoreS'
]
=
result
[
'maskStoreS'
].
replace
(
'{3}'
,
f
'{{3}}*
{
bits
[
data_type
]
//
8
}
'
)
result
[
'+int'
]
=
f
"vadd_vv_i
{
bits
[
'int'
]
}
m1({{0}}, {{1}},
{
int_vl
}
)"
result
[
'float'
]
=
f
'vfloat
{
bits
[
"float"
]
}
m1_t'
result
[
'double'
]
=
f
'vfloat
{
bits
[
"double"
]
}
m1_t'
result
[
'int'
]
=
f
'vint
{
bits
[
"int"
]
}
m1_t'
result
[
'bool'
]
=
f
'vbool
{
bits
[
data_type
]
}
_t'
result
[
'headers'
]
=
[
'<riscv_vector.h>'
]
result
[
'any'
]
+=
' > 0x0'
result
[
'all'
]
+=
f
' == vsetvl_e
{
bits
[
data_type
]
}
m1(
{
vl
}
)'
return
result
pystencils/backends/simd_instruction_sets.py
View file @
fce1a417
...
...
@@ -6,6 +6,7 @@ from ctypes import CDLL
from
pystencils.backends.x86_instruction_sets
import
get_vector_instruction_set_x86
from
pystencils.backends.arm_instruction_sets
import
get_vector_instruction_set_arm
from
pystencils.backends.ppc_instruction_sets
import
get_vector_instruction_set_ppc
from
pystencils.backends.riscv_instruction_sets
import
get_vector_instruction_set_riscv
def
get_vector_instruction_set
(
data_type
=
'double'
,
instruction_set
=
'avx'
):
...
...
@@ -13,6 +14,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):