Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
8a89d0bc
Commit
8a89d0bc
authored
Mar 15, 2019
by
Martin Bauer
Browse files
Fixes for fast_* nodes and SIMD printer
parent
2f5f6ad6
Changes
2
Hide whitespace changes
Inline
Side-by-side
backends/cbackend.py
View file @
8a89d0bc
...
...
@@ -326,14 +326,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if
type
(
data_type
)
is
VectorType
:
return
self
.
instruction_set
[
'makeVec'
].
format
(
self
.
_print
(
arg
))
elif
expr
.
func
==
fast_division
:
return
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
return
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
elif
expr
.
func
==
fast_sqrt
:
return
"({})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
])))
elif
expr
.
func
==
fast_inv_sqrt
:
if
self
.
instruction_set
[
'rsqrt'
]:
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
else
:
return
"({})"
.
format
(
self
.
doprint
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
])))
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
if
self
.
instruction_set
[
'rsqrt'
]:
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
else
:
return
"({})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
])))
return
super
(
VectorizedCustomSympyPrinter
,
self
).
_print_Function
(
expr
)
def
_print_And
(
self
,
expr
):
...
...
cpu/vectorization.py
View file @
8a89d0bc
...
...
@@ -2,6 +2,7 @@ import sympy as sp
import
warnings
from
typing
import
Union
,
Container
from
pystencils.backends.simd_instruction_sets
import
get_vector_instruction_set
from
pystencils.fast_approximation
import
fast_division
,
fast_sqrt
,
fast_inv_sqrt
from
pystencils.integer_functions
import
modulo_floor
,
modulo_ceil
from
pystencils.sympyextensions
import
fast_subs
from
pystencils.data_types
import
TypedSymbol
,
VectorType
,
get_type_of_expression
,
vector_memory_access
,
cast_func
,
\
...
...
@@ -118,10 +119,13 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
def
insert_vector_casts
(
ast_node
):
"""Inserts necessary casts from scalar values to vector values."""
handled_functions
=
(
sp
.
Add
,
sp
.
Mul
,
fast_division
,
fast_sqrt
,
fast_inv_sqrt
)
def
visit_expr
(
expr
):
if
isinstance
(
expr
,
cast_func
)
or
isinstance
(
expr
,
vector_memory_access
):
return
expr
elif
expr
.
func
in
(
sp
.
Add
,
sp
.
Mul
)
or
isinstance
(
expr
,
sp
.
Rel
)
or
isinstance
(
expr
,
sp
.
boolalg
.
BooleanFunction
):
elif
expr
.
func
in
handled_functions
or
isinstance
(
expr
,
sp
.
Rel
)
or
isinstance
(
expr
,
sp
.
boolalg
.
BooleanFunction
):
new_args
=
[
visit_expr
(
a
)
for
a
in
expr
.
args
]
arg_types
=
[
get_type_of_expression
(
a
)
for
a
in
new_args
]
if
not
any
(
type
(
t
)
is
VectorType
for
t
in
arg_types
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment