Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
da7adf35
Commit
da7adf35
authored
Dec 07, 2018
by
Martin Bauer
Browse files
CPU backend now tests if the data type of array parameters is correct
parent
2e42e5ba
Changes
1
Hide whitespace changes
Inline
Side-by-side
cpu/cpujit.py
View file @
da7adf35
...
...
@@ -248,7 +248,7 @@ template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE);
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE
| PyBUF_FORMAT
);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""
...
...
@@ -333,26 +333,38 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
post_call_code
+=
template_release_buffer
.
format
(
name
=
field
.
name
)
parameters
.
append
(
"({dtype} *)buffer_{name}.buf"
.
format
(
dtype
=
str
(
field
.
dtype
),
name
=
field
.
name
))
if
insert_checks
and
field
.
has_fixed_shape
:
shape_cond
=
[
"buffer_{name}.shape[{i}] == {s}"
.
format
(
s
=
s
,
name
=
field
.
name
,
i
=
i
)
for
i
,
s
in
enumerate
(
field
.
spatial_shape
)]
shape_cond
=
" && "
.
join
(
shape_cond
)
pre_call_code
+=
template_check_array
.
format
(
cond
=
shape_cond
,
what
=
"shape"
,
name
=
field
.
name
,
expected
=
str
(
field
.
shape
))
item_size
=
field
.
dtype
.
numpy_dtype
.
itemsize
expected_strides
=
[
e
*
item_size
for
e
in
field
.
spatial_strides
]
stride_check_code
=
"(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond
=
" && "
.
join
([
stride_check_code
.
format
(
s
=
s
,
i
=
i
,
name
=
field
.
name
)
for
i
,
s
in
enumerate
(
expected_strides
)])
pre_call_code
+=
template_check_array
.
format
(
cond
=
strides_cond
,
what
=
"strides"
,
name
=
field
.
name
,
expected
=
str
(
expected_strides
))
if
insert_checks
and
not
field
.
has_fixed_shape
:
if
FieldType
.
is_generic
(
field
):
variable_sized_normal_fields
.
add
(
field
)
elif
FieldType
.
is_indexed
(
field
):
variable_sized_index_fields
.
add
(
field
)
if
insert_checks
:
np_dtype
=
field
.
dtype
.
numpy_dtype
item_size
=
np_dtype
.
itemsize
if
np_dtype
.
isbuiltin
and
FieldType
.
is_generic
(
field
):
dtype_cond
=
"buffer_{name}.format[0] == '{format}'"
.
format
(
name
=
field
.
name
,
format
=
field
.
dtype
.
numpy_dtype
.
char
)
pre_call_code
+=
template_check_array
.
format
(
cond
=
dtype_cond
,
what
=
"data type"
,
name
=
field
.
name
,
expected
=
str
(
field
.
dtype
.
numpy_dtype
))
item_size_cond
=
"buffer_{name}.itemsize == {size}"
.
format
(
name
=
field
.
name
,
size
=
item_size
)
pre_call_code
+=
template_check_array
.
format
(
cond
=
item_size_cond
,
what
=
"itemsize"
,
name
=
field
.
name
,
expected
=
item_size
)
if
field
.
has_fixed_shape
:
shape_cond
=
[
"buffer_{name}.shape[{i}] == {s}"
.
format
(
s
=
s
,
name
=
field
.
name
,
i
=
i
)
for
i
,
s
in
enumerate
(
field
.
spatial_shape
)]
shape_cond
=
" && "
.
join
(
shape_cond
)
pre_call_code
+=
template_check_array
.
format
(
cond
=
shape_cond
,
what
=
"shape"
,
name
=
field
.
name
,
expected
=
str
(
field
.
shape
))
expected_strides
=
[
e
*
item_size
for
e
in
field
.
spatial_strides
]
stride_check_code
=
"(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond
=
" && "
.
join
([
stride_check_code
.
format
(
s
=
s
,
i
=
i
,
name
=
field
.
name
)
for
i
,
s
in
enumerate
(
expected_strides
)])
pre_call_code
+=
template_check_array
.
format
(
cond
=
strides_cond
,
what
=
"strides"
,
name
=
field
.
name
,
expected
=
str
(
expected_strides
))
else
:
if
FieldType
.
is_generic
(
field
):
variable_sized_normal_fields
.
add
(
field
)
elif
FieldType
.
is_indexed
(
field
):
variable_sized_index_fields
.
add
(
field
)
elif
param
.
is_field_stride
:
field
=
param
.
fields
[
0
]
item_size
=
field
.
dtype
.
numpy_dtype
.
itemsize
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment