Commit da7adf35 authored by Martin Bauer's avatar Martin Bauer
Browse files

CPU backend now tests if the data type of array parameters is correct

parent 2e42e5ba
......@@ -248,7 +248,7 @@ template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE);
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""
......@@ -333,26 +333,38 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
post_call_code += template_release_buffer.format(name=field.name)
parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name))
if insert_checks and field.has_fixed_shape:
shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
for i, s in enumerate(field.spatial_shape)]
shape_cond = " && ".join(shape_cond)
pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
expected=str(field.shape))
item_size = field.dtype.numpy_dtype.itemsize
expected_strides = [e * item_size for e in field.spatial_strides]
stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
for i, s in enumerate(expected_strides)])
pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
expected=str(expected_strides))
if insert_checks and not field.has_fixed_shape:
if FieldType.is_generic(field):
variable_sized_normal_fields.add(field)
elif FieldType.is_indexed(field):
variable_sized_index_fields.add(field)
if insert_checks:
np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize
if np_dtype.isbuiltin and FieldType.is_generic(field):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
format=field.dtype.numpy_dtype.char)
pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
expected=str(field.dtype.numpy_dtype))
item_size_cond = "buffer_{name}.itemsize == {size}".format(name=field.name, size=item_size)
pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name,
expected=item_size)
if field.has_fixed_shape:
shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
for i, s in enumerate(field.spatial_shape)]
shape_cond = " && ".join(shape_cond)
pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
expected=str(field.shape))
expected_strides = [e * item_size for e in field.spatial_strides]
stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
for i, s in enumerate(expected_strides)])
pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
expected=str(expected_strides))
else:
if FieldType.is_generic(field):
variable_sized_normal_fields.add(field)
elif FieldType.is_indexed(field):
variable_sized_index_fields.add(field)
elif param.is_field_stride:
field = param.fields[0]
item_size = field.dtype.numpy_dtype.itemsize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment