Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Frederik Hennig
pystencils
Commits
d8f49def
Commit
d8f49def
authored
Oct 04, 2019
by
Martin Bauer
Browse files
First try to get a better const treatment
parent
443527ae
Changes
6
Hide whitespace changes
Inline
Side-by-side
pystencils/astnodes.py
View file @
d8f49def
...
...
@@ -506,7 +506,6 @@ class SympyAssignment(Node):
super
(
SympyAssignment
,
self
).
__init__
(
parent
=
None
)
self
.
_lhs_symbol
=
lhs_symbol
self
.
rhs
=
sp
.
sympify
(
rhs_expr
)
self
.
_is_const
=
is_const
self
.
_is_declaration
=
self
.
__is_declaration
()
def
__is_declaration
(
self
):
...
...
@@ -563,10 +562,6 @@ class SympyAssignment(Node):
def
is_declaration
(
self
):
return
self
.
_is_declaration
@
property
def
is_const
(
self
):
return
self
.
_is_const
def
replace
(
self
,
child
,
replacement
):
if
child
==
self
.
lhs
:
replacement
.
parent
=
self
...
...
pystencils/backends/cbackend.py
View file @
d8f49def
...
...
@@ -225,13 +225,9 @@ class CBackend:
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
if
node
.
is_const
:
prefix
=
'const '
else
:
prefix
=
''
data_type
=
prefix
+
self
.
_print
(
node
.
lhs
.
dtype
).
replace
(
' const'
,
''
)
+
" "
return
"%s%s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
data_type
=
self
.
_print
(
node
.
lhs
.
dtype
)
return
"%s %s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
else
:
lhs_type
=
get_type_of_expression
(
node
.
lhs
)
if
type
(
lhs_type
)
is
VectorType
and
isinstance
(
node
.
lhs
,
cast_func
):
...
...
pystencils/cpu/vectorization.py
View file @
d8f49def
...
...
@@ -63,11 +63,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
if
assume_inner_stride_one
:
replace_inner_stride_with_one
(
kernel_ast
)
field_float_dtypes
=
set
(
f
.
dtype
for
f
in
all_fields
if
f
.
dtype
.
is_float
())
field_float_dtypes
=
set
(
f
.
dtype
.
numpy_dtype
for
f
in
all_fields
if
f
.
dtype
.
is_float
())
if
len
(
field_float_dtypes
)
!=
1
:
raise
NotImplementedError
(
"Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields"
)
float_size
=
field_float_dtypes
.
pop
().
numpy_dtype
.
itemsize
float_size
=
field_float_dtypes
.
pop
().
itemsize
assert
float_size
in
(
8
,
4
)
vector_is
=
get_vector_instruction_set
(
'double'
if
float_size
==
8
else
'float'
,
instruction_set
=
instruction_set
)
...
...
@@ -148,7 +148,7 @@ def insert_vector_casts(ast_node):
return
expr
else
:
target_type
=
collate_types
(
arg_types
)
casted_args
=
[
cast_func
(
a
,
target_type
)
if
t
!=
target_type
else
a
casted_args
=
[
cast_func
(
a
,
target_type
)
if
not
t
.
equal_ignoring_const
(
target_type
)
else
a
for
a
,
t
in
zip
(
new_args
,
arg_types
)]
return
expr
.
func
(
*
casted_args
)
elif
expr
.
func
is
sp
.
Pow
:
...
...
@@ -167,11 +167,11 @@ def insert_vector_casts(ast_node):
if
type
(
condition_target_type
)
is
not
VectorType
and
type
(
result_target_type
)
is
VectorType
:
condition_target_type
=
VectorType
(
condition_target_type
,
width
=
result_target_type
.
width
)
casted_results
=
[
cast_func
(
a
,
result_target_type
)
if
t
!=
result_target_type
else
a
casted_results
=
[
cast_func
(
a
,
result_target_type
)
if
not
t
.
equal_ignoring_const
(
result_target_type
)
else
a
for
a
,
t
in
zip
(
new_results
,
types_of_results
)]
casted_conditions
=
[
cast_func
(
a
,
condition_target_type
)
if
t
!=
condition_target_type
and
a
is
not
True
else
a
if
not
t
.
equal_ignoring_const
(
condition_target_type
)
and
a
is
not
True
else
a
for
a
,
t
in
zip
(
new_conditions
,
types_of_conditions
)]
return
sp
.
Piecewise
(
*
[(
r
,
c
)
for
r
,
c
in
zip
(
casted_results
,
casted_conditions
)])
...
...
pystencils/data_types.py
View file @
d8f49def
...
...
@@ -453,7 +453,7 @@ def collate_types(types, forbid_collation_to_float=False):
types
=
tuple
(
t
for
t
in
types
if
t
.
is_float
())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type
=
np
.
result_type
(
*
(
t
.
numpy_dtype
for
t
in
types
))
result
=
BasicType
(
result_numpy_type
)
result
=
BasicType
(
result_numpy_type
,
const
=
any
(
t
.
const
for
t
in
types
)
)
if
vector_type
:
result
=
VectorType
(
result
,
vector_type
[
0
].
width
)
return
result
...
...
@@ -618,6 +618,12 @@ class BasicType(Type):
else
:
return
(
self
.
numpy_dtype
,
self
.
const
)
==
(
other
.
numpy_dtype
,
other
.
const
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
BasicType
):
return
False
else
:
return
self
.
numpy_dtype
==
other
.
numpy_dtype
def
__hash__
(
self
):
return
hash
(
str
(
self
))
...
...
@@ -643,17 +649,23 @@ class VectorType(Type):
else
:
return
(
self
.
base_type
,
self
.
width
)
==
(
other
.
base_type
,
other
.
width
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
VectorType
):
return
False
else
:
return
self
.
base_type
.
equal_ignoring_const
(
other
.
base_type
)
def
__str__
(
self
):
if
self
.
instruction_set
is
None
:
return
"%s[%d]"
%
(
self
.
base_type
,
self
.
width
)
else
:
if
self
.
base_type
==
create_type
(
"
int64
"
)
:
if
self
.
base_type
.
numpy_dtype
==
np
.
int64
:
return
self
.
instruction_set
[
'int'
]
elif
self
.
base_type
==
create_type
(
"
float64
"
)
:
elif
self
.
base_type
.
numpy_dtype
==
np
.
float64
:
return
self
.
instruction_set
[
'double'
]
elif
self
.
base_type
==
create_type
(
"
float32
"
)
:
elif
self
.
base_type
.
numpy_dtype
==
np
.
float32
:
return
self
.
instruction_set
[
'float'
]
elif
self
.
base_type
==
create_type
(
"
bool
"
)
:
elif
self
.
base_type
.
numpy_dtype
==
np
.
bool
:
return
self
.
instruction_set
[
'bool'
]
else
:
raise
NotImplementedError
()
...
...
@@ -692,6 +704,12 @@ class PointerType(Type):
else
:
return
(
self
.
base_type
,
self
.
const
,
self
.
restrict
)
==
(
other
.
base_type
,
other
.
const
,
other
.
restrict
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
PointerType
):
return
False
else
:
return
self
.
base_type
.
equal_ignoring_const
(
other
.
base_type
)
def
__str__
(
self
):
components
=
[
str
(
self
.
base_type
),
'*'
]
if
self
.
restrict
:
...
...
@@ -743,6 +761,12 @@ class StructType:
else
:
return
(
self
.
numpy_dtype
,
self
.
const
)
==
(
other
.
numpy_dtype
,
other
.
const
)
def
equal_ignoring_const
(
self
,
other
):
if
not
isinstance
(
other
,
StructType
):
return
False
else
:
return
self
.
numpy_dtype
==
other
.
numpy_dtype
def
__str__
(
self
):
# structs are handled byte-wise
result
=
"uint8_t"
...
...
pystencils/kernelparameters.py
View file @
d8f49def
...
...
@@ -16,7 +16,7 @@ would reference back to the field.
from
sympy.core.cache
import
cacheit
from
pystencils.data_types
import
(
PointerType
,
TypedSymbol
,
create_composite_type_from_string
,
get_base_type
)
BasicType
,
PointerType
,
TypedSymbol
,
create_composite_type_from_string
,
get_base_type
)
SHAPE_DTYPE
=
create_composite_type_from_string
(
"const int64"
)
STRIDE_DTYPE
=
create_composite_type_from_string
(
"const int64"
)
...
...
@@ -78,7 +78,8 @@ class FieldPointerSymbol(TypedSymbol):
def
__new_stage2__
(
cls
,
field_name
,
field_dtype
,
const
):
name
=
"_data_{name}"
.
format
(
name
=
field_name
)
dtype
=
PointerType
(
get_base_type
(
field_dtype
),
const
=
const
,
restrict
=
True
)
base_type
=
BasicType
(
get_base_type
(
field_dtype
),
const
=
const
)
dtype
=
PointerType
(
base_type
,
const
=
True
,
restrict
=
True
)
obj
=
super
(
FieldPointerSymbol
,
cls
).
__xnew__
(
cls
,
name
,
dtype
)
obj
.
field_name
=
field_name
return
obj
...
...
pystencils/transformations.py
View file @
d8f49def
...
...
@@ -878,7 +878,9 @@ class KernelConstraintsCheck:
assert
isinstance
(
lhs
,
sp
.
Symbol
)
self
.
_update_accesses_lhs
(
lhs
)
if
not
isinstance
(
lhs
,
(
AbstractField
.
AbstractAccess
,
TypedSymbol
)):
return
TypedSymbol
(
lhs
.
name
,
self
.
_type_for_symbol
[
lhs
.
name
])
dtype
=
create_type
(
self
.
_type_for_symbol
[
lhs
.
name
])
dtype
.
const
=
True
return
TypedSymbol
(
lhs
.
name
,
dtype
)
else
:
return
lhs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment