Jonas Plewinski / pystencils / Commits / 866e9fc0

Commit 866e9fc0, authored May 14, 2018 by Martin Bauer

Fixes in vectorization to also support float kernels

Parent: 501b2d7e
Changes: 7 files
backends/cbackend.py

@@ -213,6 +213,7 @@ class CustomSympyPrinter(CCodePrinter):
     def __init__(self, constants_as_floats=False):
         self._constantsAsFloats = constants_as_floats
         super(CustomSympyPrinter, self).__init__()
+        self._float_type = create_type("float32")
 
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
@@ -224,8 +225,6 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_Rational(self, expr):
         """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
         res = str(expr.evalf().num)
-        if self._constantsAsFloats:
-            res += "f"
         return res
 
     def _print_Equality(self, expr):
@@ -237,12 +236,6 @@ class CustomSympyPrinter(CCodePrinter):
         result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
         return result.replace("\n", "")
 
-    def _print_Float(self, expr):
-        res = str(expr)
-        if self._constantsAsFloats:
-            res += "f"
-        return res
-
     def _print_Function(self, expr):
         function_map = {
             bitwise_xor: '^',
@@ -255,7 +248,10 @@ class CustomSympyPrinter(CCodePrinter):
             return expr.to_c(self._print)
         if expr.func == cast_func:
             arg, data_type = expr.args
-            return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
+            if isinstance(arg, sp.Number):
+                return self._typed_number(arg, data_type)
+            else:
+                return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
         elif expr.func == modulo_floor:
             assert all(get_type_of_expression(e).is_int() for e in expr.args)
             return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
@@ -264,6 +260,17 @@ class CustomSympyPrinter(CCodePrinter):
         else:
             return super(CustomSympyPrinter, self)._print_Function(expr)
 
+    def _typed_number(self, number, dtype):
+        res = self._print(number)
+        if dtype.is_float:
+            if dtype == self._float_type:
+                if '.' not in res:
+                    res += ".0f"
+                else:
+                    res += "f"
+            return res
+        else:
+            return res
 
 # noinspection PyPep8Naming
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
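The new _typed_number helper is what makes float32 kernels emit single-precision literals: constants only get an 'f' suffix (and a '.0' when they look like integers) when the target dtype is float32. A minimal standalone sketch of that suffix rule, using a plain flag in place of the comparison against self._float_type (the function name and flags are illustrative, not the pystencils API):

def typed_number_literal(value, is_float_dtype, is_float32):
    """Sketch of the suffix rule implemented by CustomSympyPrinter._typed_number:
    float32 constants get an 'f' suffix, other dtypes are printed unchanged."""
    res = str(value)
    if is_float_dtype and is_float32:
        if '.' not in res:
            res += ".0f"      # e.g. 4 -> "4.0f"
        else:
            res += "f"        # e.g. 0.25 -> "0.25f"
    return res

assert typed_number_literal(4, True, True) == "4.0f"
assert typed_number_literal(0.25, True, True) == "0.25f"
assert typed_number_literal(4, False, False) == "4"   # non-float constant stays plain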
backends/simd_instruction_sets.py

@@ -20,7 +20,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
                   'sqrt': 'sqrt[0]',
-                  'makeVec': 'set[0,0,0,0]',
+                  'makeVec': 'set[]',
                   'makeZero': 'setzero[]',
                   'loadU': 'loadu[0]',
@@ -31,6 +31,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
                   }
 
    headers = {
+       'avx512': ['<immintrin.h>'],
        'avx': ['<immintrin.h>'],
        'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
    }
@@ -54,32 +55,37 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         ("float", "avx512"): 16,
     }
 
-    result = {}
+    result = {
+        'width': width[(data_type, instruction_set)],
+    }
     pre = prefix[instruction_set]
     suf = suffix[data_type]
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
         name = function_shortcut[:function_shortcut.index('[')]
-        args = function_shortcut[function_shortcut.index('[') + 1:-1]
-        arg_string = "("
-        for arg in args.split(","):
-            arg = arg.strip()
-            if not arg:
-                continue
-            if arg in ('0', '1', '2', '3', '4', '5'):
-                arg_string += "{" + arg + "},"
-            else:
-                arg_string += arg + ","
-        arg_string = arg_string[:-1] + ")"
+        if intrinsic_id == 'makeVec':
+            arg_string = "({})".format(",".join(["{0}"] * result['width']))
+        else:
+            args = function_shortcut[function_shortcut.index('[') + 1:-1]
+            arg_string = "("
+            for arg in args.split(","):
+                arg = arg.strip()
+                if not arg:
+                    continue
+                if arg in ('0', '1', '2', '3', '4', '5'):
+                    arg_string += "{" + arg + "},"
+                else:
+                    arg_string += arg + ","
+            arg_string = arg_string[:-1] + ")"
         result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string
-    result['width'] = width[(data_type, instruction_set)]
 
     result['dataTypePrefix'] = {
         'double': "_" + pre + 'd',
         'float': "_" + pre,
     }
 
-    bit_width = result['width'] * 64
+    bit_width = result['width'] * (64 if data_type == 'double' else 32)
     result['double'] = "__m%dd" % (bit_width,)
     result['float'] = "__m%d" % (bit_width,)
     result['int'] = "__m%di" % (bit_width,)
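Switching 'makeVec' to 'set[]' lets the broadcast intrinsic take as many arguments as the register has lanes, instead of the previously hardcoded four doubles. A small sketch of how the placeholder string is expanded from the width entry (illustrative helper, not the library function; the real code also builds the intrinsic prefix and suffix from its lookup tables):

def make_vec_placeholder(width):
    # Mirrors: arg_string = "({})".format(",".join(["{0}"] * result['width']))
    return "({})".format(",".join(["{0}"] * width))

# double/AVX -> 4 lanes, float/AVX -> 8 lanes
assert make_vec_placeholder(4) == "({0},{0},{0},{0})"
assert make_vec_placeholder(8) == "({0},{0},{0},{0},{0},{0},{0},{0})"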
cpu/vectorization.py

@@ -13,13 +13,13 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration
 from pystencils.field import Field
 
 
-def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx',
+def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
               assume_aligned: bool = False,
               nontemporal: Union[bool, Container[Union[str, Field]]] = False):
     """Explicit vectorization using SIMD vectorization via intrinsics.
 
     Args:
         kernel_ast: abstract syntax tree (KernelFunction node)
-        vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
+        instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
         assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                         used. If true, some of the loads are assumed to be from aligned memory addresses.
                         For example if x is the fastest coordinate, the access to center can be fetched via an
@@ -42,7 +42,7 @@ def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx
     float_size = field_float_dtypes.pop().numpy_dtype.itemsize
     assert float_size in (8, 4)
     vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
-                                           instruction_set=vector_instruction_set)
+                                           instruction_set=instruction_set)
     vector_width = vector_is['width']
     kernel_ast.instruction_set = vector_is
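Besides the keyword rename (vector_instruction_set -> instruction_set), the function keeps choosing between the 'double' and 'float' instruction tables from the itemsize of the fields' floating-point dtype. A self-contained sketch of that selection, using plain numpy dtypes instead of pystencils field types (hypothetical helper name):

import numpy as np

def simd_data_type_for(field_float_dtypes):
    """Sketch of the selection done inside vectorize(): pick the SIMD data type
    name from the itemsize of the fields' float dtype."""
    float_size = field_float_dtypes.pop().itemsize
    assert float_size in (8, 4)
    return 'double' if float_size == 8 else 'float'

assert simd_data_type_for({np.dtype('float64')}) == 'double'
assert simd_data_type_for({np.dtype('float32')}) == 'float'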
data_types.py

@@ -289,7 +289,10 @@ def get_type_of_expression(expr):
     from pystencils.astnodes import ResolvedFieldAccess
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
-        return create_type("int")
+        if expr == 1 or expr == -1:
+            return create_type("int16")
+        else:
+            return create_type("int")
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type("double")
     elif isinstance(expr, ResolvedFieldAccess):
@@ -316,6 +319,8 @@ def get_type_of_expression(expr):
         if vec_args:
             result = VectorType(result, width=vec_args[0].width)
         return result
+    elif isinstance(expr, sp.Pow):
+        return get_type_of_expression(expr.args[0])
     elif isinstance(expr, sp.Expr):
         types = tuple(get_type_of_expression(a) for a in expr.args)
         return collate_types(types)
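Typing the literals 1 and -1 as int16, and typing a Pow by its base, keeps sign flips and small constant factors from widening float expressions. A simplified sketch of the updated dispatch, with plain type-name strings standing in for create_type (hypothetical helper, not the pystencils function):

import sympy as sp

def type_name_of(expr):
    """Simplified sketch of the updated get_type_of_expression dispatch."""
    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        # +1/-1 get the narrow int16 type so factors like -1 * x
        # do not widen float expressions
        return "int16" if expr in (1, -1) else "int"
    elif isinstance(expr, (sp.Rational, sp.Float)):
        return "double"
    elif isinstance(expr, sp.Pow):
        # the exponent is ignored, the base determines the type
        return type_name_of(expr.args[0])
    return "unknown"

assert type_name_of(-1) == "int16"
assert type_name_of(7) == "int"
assert type_name_of(sp.Symbol("x") ** 2) == "unknown"   # symbols are untyped in this sketch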
kernelcreation.py

@@ -73,7 +73,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
            add_openmp(ast, num_threads=cpu_openmp)
        if cpu_vectorize_info:
            if cpu_vectorize_info is True:
-               vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None)
+               vectorize(ast, instruction_set='avx', assume_aligned=False, nontemporal=None)
            elif isinstance(cpu_vectorize_info, dict):
                vectorize(ast, **cpu_vectorize_info)
            else:
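For callers this only changes the keyword accepted inside a cpu_vectorize_info dict. A small sketch of how create_kernel forwards that argument after the rename (illustrative dispatcher, not the pystencils source):

def dispatch_vectorization(cpu_vectorize_info, vectorize):
    """Sketch of the cpu_vectorize_info handling in create_kernel."""
    if cpu_vectorize_info:
        if cpu_vectorize_info is True:
            vectorize(instruction_set='avx', assume_aligned=False, nontemporal=None)
        elif isinstance(cpu_vectorize_info, dict):
            vectorize(**cpu_vectorize_info)   # dict keys must use the new keyword name
        else:
            raise ValueError("invalid cpu_vectorize_info")

calls = []
dispatch_vectorization({'instruction_set': 'sse', 'assume_aligned': True},
                       lambda **kw: calls.append(kw))
assert calls == [{'instruction_set': 'sse', 'assume_aligned': True}]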
llvm/llvm.py

@@ -205,10 +205,17 @@ class LLVMPrinter(Printer):
         node = self._print(conversion.args[0])
         to_dtype = get_type_of_expression(conversion)
         from_dtype = get_type_of_expression(conversion.args[0])
+        if from_dtype == to_dtype:
+            return self._print(conversion.args[0])
 
+        # (From, to)
         decision = {
+            (create_composite_type_from_string("int16"),
+             create_composite_type_from_string("int64")): lambda: ir.Constant(self.integer, node),
             (create_composite_type_from_string("int"),
              create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
+            (create_composite_type_from_string("int16"),
+             create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
             (create_composite_type_from_string("double"),
              create_composite_type_from_string("int")): functools.partial(self.builder.fptosi, node, self.integer),
             (create_composite_type_from_string("double *"),
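The two new dictionary entries exist because ±1 constants are now typed as int16 and therefore need explicit conversions to int64 and double; the early return skips conversions whose source and target types already match. A toy sketch of the (from, to) dispatch pattern, with strings in place of pystencils type objects and llvmlite builder calls:

def pick_conversion(from_dtype, to_dtype):
    """Dispatch-by-(from, to)-pair sketch; returns a label instead of an LLVM op."""
    if from_dtype == to_dtype:
        return "no-op"                     # new early exit in this commit
    decision = {
        ("int16", "int64"): "constant",    # widen small integer constants
        ("int16", "double"): "sitofp",     # new: signed int -> floating point
        ("int", "double"): "sitofp",
        ("double", "int"): "fptosi",
    }
    return decision[(from_dtype, to_dtype)]

assert pick_conversion("int16", "int16") == "no-op"
assert pick_conversion("int16", "double") == "sitofp"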
transformations.py

@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
 from pystencils.assignment import Assignment
 from pystencils.field import Field, FieldType
 from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
-    pointer_arithmetic_func, get_type_of_expression, collate_types
+    pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
 from pystencils.slicing import normalize_slice
 import pystencils.astnodes as ast
@@ -716,9 +716,18 @@ class KernelConstraintsCheck:
             return rhs
         elif isinstance(rhs, sp.Symbol):
             return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
-        else:
-            new_args = [self.process_expression(arg) for arg in rhs.args]
+        elif isinstance(rhs, sp.Number):
+            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
+        elif isinstance(rhs, sp.Mul):
+            new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
             return rhs.func(*new_args) if new_args else rhs
+        else:
+            if isinstance(rhs, sp.Pow):
+                # don't process exponents -> they should remain integers
+                return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
+            else:
+                new_args = [self.process_expression(arg) for arg in rhs.args]
+                return rhs.func(*new_args) if new_args else rhs
 
     @property
     def fields_written(self):
@@ -800,10 +809,13 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
 def insert_casts(node):
-    """Checks the types and inserts casts and pointer arithmetic where necessary
+    """Checks the types and inserts casts and pointer arithmetic where necessary.
 
-    :param node: the head node of the ast
-    :return: modified ast
+    Args:
+        node: the head node of the ast
+
+    Returns:
+        modified AST
     """
 
     def cast(zipped_args_types, target_dtype):
         """
@@ -839,7 +851,7 @@ def insert_casts(node):
         new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
         return pointer_arithmetic_func(pointer, new_args)
-    if isinstance(node, sp.AtomicExpr):
+    if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
         return node
     args = []
     for arg in node.args:
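process_expression now wraps plain numbers in cast_func with the kernel's default constant type, but deliberately leaves the factors 1 and -1 of a Mul and the exponents of a Pow untouched, so a sign flip stays a sign flip and exponents remain integers. A simplified sketch with a tuple standing in for cast_func (hypothetical names, not the pystencils classes):

import sympy as sp

CONSTANT_TYPE = "float32"   # stands in for self._type_for_symbol['_constant']

def process(expr):
    """Simplified sketch of KernelConstraintsCheck.process_expression after this commit."""
    if isinstance(expr, sp.Number):
        return ("cast", expr, CONSTANT_TYPE)          # stands in for cast_func(...)
    elif isinstance(expr, sp.Mul):
        # leave the factors 1 and -1 alone so a sign flip stays a sign flip
        new_args = [process(a) if a not in (-1, 1) else a for a in expr.args]
        return ("mul", tuple(new_args))
    elif isinstance(expr, sp.Pow):
        # don't process exponents -> they should remain integers
        return ("pow", process(expr.args[0]), expr.args[1])
    return expr

x = sp.Symbol("x")
assert process(sp.Integer(3)) == ("cast", 3, CONSTANT_TYPE)
assert process(-x) == ("mul", (-1, x))                 # the -1 factor is not cast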