Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jean-Noël Grad
pystencils
Commits
d8f49def
Commit
d8f49def
authored
Oct 04, 2019
by
Martin Bauer
Browse files
First try to get a better const treatment
parent
443527ae
Changes
6
Hide whitespace changes
Inline
Side-by-side
pystencils/astnodes.py
View file @
d8f49def
...
...
@@ -506,7 +506,6 @@ class SympyAssignment(Node):
super
(
SympyAssignment
,
self
).
__init__
(
parent
=
None
)
self
.
_lhs_symbol
=
lhs_symbol
self
.
rhs
=
sp
.
sympify
(
rhs_expr
)
self
.
_is_const
=
is_const
self
.
_is_declaration
=
self
.
__is_declaration
()
def
__is_declaration
(
self
):
...
...
@@ -563,10 +562,6 @@ class SympyAssignment(Node):
def
is_declaration
(
self
):
return
self
.
_is_declaration
@
property
def
is_const
(
self
):
return
self
.
_is_const
def
replace
(
self
,
child
,
replacement
):
if
child
==
self
.
lhs
:
replacement
.
parent
=
self
...
...
pystencils/backends/cbackend.py
View file @
d8f49def
...
...
@@ -225,13 +225,9 @@ class CBackend:
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
if
node
.
is_const
:
prefix
=
'const '
else
:
prefix
=
''
data_type
=
prefix
+
self
.
_print
(
node
.
lhs
.
dtype
).
replace
(
' const'
,
''
)
+
" "
return
"%s%s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
data_type
=
self
.
_print
(
node
.
lhs
.
dtype
)
return
"%s %s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
else
:
lhs_type
=
get_type_of_expression
(
node
.
lhs
)
if
type
(
lhs_type
)
is
VectorType
and
isinstance
(
node
.
lhs
,
cast_func
):
...
...
pystencils/cpu/vectorization.py
View file @
d8f49def
...
...
@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
if
assume_inner_stride_one
:
replace_inner_stride_with_one
(
kernel_ast
)
field_float_dtypes
=
set
(
f
.
dtype
for
f
in
all_fields
if
f
.
dtype
.
is_float
())
field_float_dtypes
=
set
(
f
.
dtype
.
numpy_dtype
for
f
in
all_fields
if
f
.
dtype
.
is_float
())
if
len
(
field_float_dtypes
)
!=
1
:
raise
NotImplementedError
(
"Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields"
)
float_size
=
field_float_dtypes
.
pop
().
numpy_dtype
.
itemsize
float_size
=
field_float_dtypes
.
pop
().
itemsize
assert
float_size
in
(
8
,
4
)
vector_is
=
get_vector_instruction_set
(
'double'
if
float_size
==
8
else
'float'
,
instruction_set
=
instruction_set
)
...
...
@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node):
return
expr
else
:
target_type
=
collate_types
(
arg_types
)
casted_args
=
[
cast_func
(
a
,
target_type
)
if
t
!=
target_type
else
a
casted_args
=
[
cast_func
(
a
,
target_type
)
if
not
t
.
equal_ignoring_const
(
target_type
)
else
a
for
a
,
t
in
zip
(
new_args
,
arg_types
)]
return
expr
.
func
(
*
casted_args
)
elif
expr
.
func
is
sp
.
Pow
:
...
...
@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node):
if
type
(
condition_target_type
)
is
not
VectorType
and
type
(
result_target_type
)
is
VectorType
:
condition_target_type
=
VectorType
(
condition_target_type
,
width
=
result_target_type
.
width
)
casted_results
=
[
cast_func
(
a
,
result_target_type
)
if
t
!=
result_target_type
else
a
casted_results
=
[
cast_func
(
a
,
result_target_type
)
if
not
t
.
equal_ignoring_const
(
result_target_type
)
else
a
for
a
,
t
in
zip
(
new_results
,
types_of_results
)]
casted_conditions
=
[
cast_func
(
a
,
condition_target_type
)
if
t
!=
condition_target_type
and
a
is
not
True
else
a
if
not
t
.
equal_ignoring_const
(
condition_target_type
)
and
a
is
not
True
else
a
for
a
,
t
in
zip
(
new_conditions
,
types_of_conditions
)]
return
sp
.
Piecewise
(
*
[(
r
,
c
)
for
r
,
c
in
zip
(
casted_results
,
casted_conditions
)])
...
...
pystencils/data_types.py
View file @
d8f49def
...
...
@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False):
types
=
tuple
(
t
for
t
in
types
if
t
.
is_float
())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type
=
np
.
result_type
(
*
(
t
.
numpy_dtype
for
t
in
types
))
result
=
BasicType
(
result_numpy_type
)
result
=
BasicType
(
result_numpy_type
,
const
=
any
(
t
.
const
for
t
in
types
)
)
if
vector_type
:
result
=
VectorType
(
result
,
vector_type
[
0
].
width
)
return
result
...
...
@@ -618,6 +618,12 @@ class BasicType(Type):
else
:
return
(
self
.
numpy_dtype
,
self
.
const
)
==
(
other
.
numpy_dtype
,
other
.
const
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
BasicType
):
return
False
else
:
return
self
.
numpy_dtype
==
other
.
numpy_dtype
def
__hash__
(
self
):
return
hash
(
str
(
self
))
...
...
@@ -643,17 +649,23 @@ class VectorType(Type):
else
:
return
(
self
.
base_type
,
self
.
width
)
==
(
other
.
base_type
,
other
.
width
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
VectorType
):
return
False
else
:
return
self
.
base_type
.
equal_ignoring_const
(
other
.
base_type
)
def
__str__
(
self
):
if
self
.
instruction_set
is
None
:
return
"%s[%d]"
%
(
self
.
base_type
,
self
.
width
)
else
:
if
self
.
base_type
==
create_type
(
"
int64
"
)
:
if
self
.
base_type
.
numpy_dtype
==
np
.
int64
:
return
self
.
instruction_set
[
'int'
]
elif
self
.
base_type
==
create_type
(
"
float64
"
)
:
elif
self
.
base_type
.
numpy_dtype
==
np
.
float64
:
return
self
.
instruction_set
[
'double'
]
elif
self
.
base_type
==
create_type
(
"
float32
"
)
:
elif
self
.
base_type
.
numpy_dtype
==
np
.
float32
:
return
self
.
instruction_set
[
'float'
]
elif
self
.
base_type
==
create_type
(
"
bool
"
)
:
elif
self
.
base_type
.
numpy_dtype
==
np
.
bool
:
return
self
.
instruction_set
[
'bool'
]
else
:
raise
NotImplementedError
()
...
...
@@ -692,6 +704,12 @@ class PointerType(Type):
else
:
return
(
self
.
base_type
,
self
.
const
,
self
.
restrict
)
==
(
other
.
base_type
,
other
.
const
,
other
.
restrict
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
PointerType
):
return
False
else
:
return
self
.
base_type
.
equal_ignoring_const
(
other
.
base_type
)
def
__str__
(
self
):
components
=
[
str
(
self
.
base_type
),
'*'
]
if
self
.
restrict
:
...
...
@@ -743,6 +761,12 @@ class StructType:
else
:
return
(
self
.
numpy_dtype
,
self
.
const
)
==
(
other
.
numpy_dtype
,
other
.
const
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
StructType
):
return
False
else
:
return
self
.
numpy_dtype
==
other
.
numpy_dtype
def
__str__
(
self
):
# structs are handled byte-wise
result
=
"uint8_t"
...
...
pystencils/kernelparameters.py
View file @
d8f49def
...
...
@@ -16,7 +16,7 @@ would reference back to the field.
from
sympy.core.cache
import
cacheit
from
pystencils.data_types
import
(
PointerType
,
TypedSymbol
,
create_composite_type_from_string
,
get_base_type
)
BasicType
,
PointerType
,
TypedSymbol
,
create_composite_type_from_string
,
get_base_type
)
SHAPE_DTYPE
=
create_composite_type_from_string
(
"const int64"
)
STRIDE_DTYPE
=
create_composite_type_from_string
(
"const int64"
)
...
...
@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol):
def
__new_stage2__
(
cls
,
field_name
,
field_dtype
,
const
):
name
=
"_data_{name}"
.
format
(
name
=
field_name
)
dtype
=
PointerType
(
get_base_type
(
field_dtype
),
const
=
const
,
restrict
=
True
)
base_type
=
BasicType
(
get_base_type
(
field_dtype
),
const
=
const
)
dtype
=
PointerType
(
base_type
,
const
=
True
,
restrict
=
True
)
obj
=
super
(
FieldPointerSymbol
,
cls
).
__xnew__
(
cls
,
name
,
dtype
)
obj
.
field_name
=
field_name
return
obj
...
...
pystencils/transformations.py
View file @
d8f49def
...
...
@@ -878,7 +878,9 @@ class KernelConstraintsCheck:
assert
isinstance
(
lhs
,
sp
.
Symbol
)
self
.
_update_accesses_lhs
(
lhs
)
if
not
isinstance
(
lhs
,
(
AbstractField
.
AbstractAccess
,
TypedSymbol
)):
return
TypedSymbol
(
lhs
.
name
,
self
.
_type_for_symbol
[
lhs
.
name
])
dtype
=
create_type
(
self
.
_type_for_symbol
[
lhs
.
name
])
dtype
.
const
=
True
return
TypedSymbol
(
lhs
.
name
,
dtype
)
else
:
return
lhs
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment