Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
85053df1
Commit
85053df1
authored
Oct 15, 2019
by
Martin Bauer
Browse files
More general vectorization
parent
a822ffc9
Changes
5
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
85053df1
...
...
@@ -233,11 +233,17 @@ class CBackend:
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
else
:
lhs_type
=
get_type_of_expression
(
node
.
lhs
)
printed_mask
=
""
if
type
(
lhs_type
)
is
VectorType
and
isinstance
(
node
.
lhs
,
cast_func
):
arg
,
data_type
,
aligned
,
nontemporal
=
node
.
lhs
.
args
arg
,
data_type
,
aligned
,
nontemporal
,
mask
=
node
.
lhs
.
args
instr
=
'storeU'
if
aligned
:
instr
=
'stream'
if
nontemporal
else
'storeA'
if
mask
!=
True
:
instr
=
'maskStore'
if
aligned
else
'maskStoreU'
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
if
self
.
_vector_instruction_set
[
'dataTypePrefix'
][
'double'
]
==
'__mm256d'
:
printed_mask
=
"_mm256_castpd_si256({})"
.
format
(
printed_mask
)
rhs_type
=
get_type_of_expression
(
node
.
rhs
)
if
type
(
rhs_type
)
is
not
VectorType
:
...
...
@@ -246,7 +252,8 @@ class CBackend:
rhs
=
node
.
rhs
return
self
.
_vector_instruction_set
[
instr
].
format
(
"&"
+
self
.
sympy_printer
.
doprint
(
node
.
lhs
.
args
[
0
]),
self
.
sympy_printer
.
doprint
(
rhs
))
+
';'
self
.
sympy_printer
.
doprint
(
rhs
),
printed_mask
)
+
';'
else
:
return
"%s = %s;"
%
(
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
...
...
@@ -450,7 +457,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def
_print_Function
(
self
,
expr
):
if
isinstance
(
expr
,
vector_memory_access
):
arg
,
data_type
,
aligned
,
_
=
expr
.
args
arg
,
data_type
,
aligned
,
_
,
mask
=
expr
.
args
instruction
=
self
.
instruction_set
[
'loadA'
]
if
aligned
else
self
.
instruction_set
[
'loadU'
]
return
instruction
.
format
(
"& "
+
self
.
_print
(
arg
))
elif
isinstance
(
expr
,
cast_func
):
...
...
pystencils/backends/simd_instruction_sets.py
View file @
85053df1
...
...
@@ -32,7 +32,24 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'storeU'
:
'storeu[0,1]'
,
'storeA'
:
'store[0,1]'
,
'stream'
:
'stream[0,1]'
,
'maskstore'
:
'mask_store[0, 2, 1]'
if
instruction_set
==
'avx512'
else
'maskstore[0, 2, 1]'
,
'maskload'
:
'mask_load[0, 2, 1]'
if
instruction_set
==
'avx512'
else
'maskload[0, 2, 1]'
}
if
instruction_set
==
'avx512'
:
base_names
.
update
({
'maskStore'
:
'mask_store[0, 2, 1]'
,
'maskStoreU'
:
'mask_storeu[0, 2, 1]'
,
'maskLoad'
:
'mask_load[2, 1, 0]'
,
'maskLoadU'
:
'mask_loadu[2, 1, 0]'
})
if
instruction_set
==
'avx'
:
base_names
.
update
({
'maskStore'
:
'maskstore[0, 2, 1]'
,
'maskStoreU'
:
'maskstore[0, 2, 1]'
,
'maskLoad'
:
'maskload[0, 1]'
,
'maskLoadU'
:
'maskloadu[0, 1]'
})
for
comparison_op
,
constant
in
comparisons
.
items
():
base_names
[
comparison_op
]
=
'cmp[0, 1, %s]'
%
(
constant
,)
...
...
pystencils/cpu/vectorization.py
View file @
85053df1
...
...
@@ -123,7 +123,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
nontemporal
=
False
if
hasattr
(
indexed
,
'field'
):
nontemporal
=
(
indexed
.
field
in
nontemporal_fields
)
or
(
indexed
.
field
.
name
in
nontemporal_fields
)
substitutions
[
indexed
]
=
vector_memory_access
(
indexed
,
vec_type
,
use_aligned_access
,
nontemporal
)
substitutions
[
indexed
]
=
vector_memory_access
(
indexed
,
vec_type
,
use_aligned_access
,
nontemporal
,
True
)
if
not
successful
:
warnings
.
warn
(
"Could not vectorize loop because of non-consecutive memory access"
)
continue
...
...
@@ -136,6 +136,30 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
fast_subs
(
loop_node
,
{
loop_counter_symbol
:
vector_loop_counter
},
skip
=
lambda
e
:
isinstance
(
e
,
ast
.
ResolvedFieldAccess
)
or
isinstance
(
e
,
vector_memory_access
))
mask_conditionals
(
loop_node
)
def mask_conditionals(loop_body):
    """Push the conditions of conditionals down into vector memory accesses.

    Walks the AST rooted at ``loop_body``. For each ``ast.Conditional`` the
    branch condition is AND-combined with the mask inherited from enclosing
    conditionals and handed to the true block; the negated condition is
    handed to the false block, if there is one. Afterwards the conditional's
    own condition is wrapped in ``vec_any``. Inside a masked branch, every
    ``vector_memory_access`` occurring in a ``SympyAssignment`` is rebuilt
    with its fifth argument (the mask) AND-combined with the active mask.

    Args:
        loop_body: root AST node to process. Conditionals below it get their
            ``condition_expr`` reassigned.
    """

    def visit_node(node, mask):
        if isinstance(node, ast.Conditional):
            true_mask = sp.And(node.condition_expr, mask)
            visit_node(node.true_block, true_mask)
            if node.false_block:
                false_mask = sp.And(sp.Not(node.condition_expr), mask)
                # Descend into the else-branch. Visiting `node` itself here
                # would re-enter this same conditional and never terminate.
                visit_node(node.false_block, false_mask)
            # Must happen after both masks were built from the raw condition.
            node.condition_expr = vec_any(node.condition_expr)
        elif isinstance(node, ast.SympyAssignment):
            # `True` is the sentinel for "no enclosing conditional".
            if mask is not True:
                masked_accesses = {
                    access: vector_memory_access(*access.args[:4],
                                                 sp.And(mask, access.args[4]))
                    for access in node.atoms(vector_memory_access)
                }
                # NOTE(review): the return value is discarded, so this relies
                # on the AST node's `subs` substituting in place - confirm.
                node.subs(masked_accesses)
        else:
            for arg in node.args:
                visit_node(arg, mask)

    visit_node(loop_body, mask=True)
def
insert_vector_casts
(
ast_node
):
"""Inserts necessary casts from scalar values to vector values."""
...
...
@@ -143,8 +167,10 @@ def insert_vector_casts(ast_node):
handled_functions
=
(
sp
.
Add
,
sp
.
Mul
,
fast_division
,
fast_sqrt
,
fast_inv_sqrt
,
vec_any
,
vec_all
)
def
visit_expr
(
expr
):
if
isinstance
(
expr
,
cast_func
)
or
isinstance
(
expr
,
vector_memory_access
):
if
isinstance
(
expr
,
vector_memory_access
):
return
vector_memory_access
(
expr
.
args
[
0
],
expr
.
args
[
1
],
expr
.
args
[
2
],
expr
.
args
[
3
],
visit_expr
(
expr
.
args
[
4
]))
elif
isinstance
(
expr
,
cast_func
):
return
expr
elif
expr
.
func
in
handled_functions
or
isinstance
(
expr
,
sp
.
Rel
)
or
isinstance
(
expr
,
sp
.
boolalg
.
BooleanFunction
):
new_args
=
[
visit_expr
(
a
)
for
a
in
expr
.
args
]
...
...
@@ -199,10 +225,12 @@ def insert_vector_casts(ast_node):
new_lhs
=
TypedSymbol
(
assignment
.
lhs
.
name
,
new_lhs_type
)
substitution_dict
[
assignment
.
lhs
]
=
new_lhs
assignment
.
lhs
=
new_lhs
elif
isinstance
(
assignment
.
lhs
.
func
,
cast_func
):
lhs_type
=
assignment
.
lhs
.
args
[
1
]
if
type
(
lhs_type
)
is
VectorType
and
type
(
rhs_type
)
is
not
VectorType
:
assignment
.
rhs
=
cast_func
(
assignment
.
rhs
,
lhs_type
)
elif
isinstance
(
assignment
.
lhs
,
vector_memory_access
):
assignment
.
lhs
=
visit_expr
(
assignment
.
lhs
)
#elif isinstance(assignment.lhs, cast_func): # TODO check if necessary
# lhs_type = assignment.lhs.args[1]
# if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
# assignment.rhs = cast_func(assignment.rhs, lhs_type)
elif
isinstance
(
arg
,
ast
.
Conditional
):
arg
.
condition_expr
=
fast_subs
(
arg
.
condition_expr
,
substitution_dict
,
skip
=
lambda
e
:
isinstance
(
e
,
ast
.
ResolvedFieldAccess
))
...
...
pystencils/data_types.py
View file @
85053df1
...
...
@@ -190,7 +190,8 @@ class boolean_cast_func(cast_func, Boolean):
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    # Arguments are: read/write expression, type, aligned, nontemporal, mask.
    # Callers pass True as the mask when no masking is applied.
    nargs = (5,)
# noinspection PyPep8Naming
...
...
pystencils/transformations.py
View file @
85053df1
...
...
@@ -891,10 +891,10 @@ class KernelConstraintsCheck:
if
isinstance
(
lhs
,
AbstractField
.
AbstractAccess
):
fai
=
self
.
FieldAndIndex
(
lhs
.
field
,
lhs
.
index
)
self
.
_field_writes
[
fai
].
add
(
lhs
.
offsets
)
if
len
(
self
.
_field_writes
[
fai
])
>
1
:
raise
ValueError
(
"Field {} is written at two different locations"
.
format
(
lhs
.
field
.
name
))
#
if len(self._field_writes[fai]) > 1:
#
raise ValueError(
#
"Field {} is written at two different locations".format(
#
lhs.field.name))
elif
isinstance
(
lhs
,
sp
.
Symbol
):
if
self
.
scopes
.
is_defined_locally
(
lhs
):
raise
ValueError
(
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment