diff --git a/cpu/cpujit.py b/cpu/cpujit.py index 807004825a1fdbdb317b2d534c2320187a32e03e..c083a7fb3e5d07e2ca8862cbcd644cf8f98284ed 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -248,7 +248,7 @@ template_extract_array = """ PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; Py_buffer buffer_{name}; -int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE); +int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT); if (buffer_{name}_res == -1) {{ return NULL; }} """ @@ -333,26 +333,38 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True): post_call_code += template_release_buffer.format(name=field.name) parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name)) - if insert_checks and field.has_fixed_shape: - shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i) - for i, s in enumerate(field.spatial_shape)] - shape_cond = " && ".join(shape_cond) - pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name, - expected=str(field.shape)) - - item_size = field.dtype.numpy_dtype.itemsize - expected_strides = [e * item_size for e in field.spatial_strides] - stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)" - strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name) - for i, s in enumerate(expected_strides)]) - pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name, - expected=str(expected_strides)) - - if insert_checks and not field.has_fixed_shape: - if FieldType.is_generic(field): - variable_sized_normal_fields.add(field) - elif FieldType.is_indexed(field): - variable_sized_index_fields.add(field) + if insert_checks: + np_dtype = field.dtype.numpy_dtype + item_size = np_dtype.itemsize + + if np_dtype.isbuiltin and FieldType.is_generic(field): + dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name, + format=field.dtype.numpy_dtype.char) + pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name, + expected=str(field.dtype.numpy_dtype)) + + item_size_cond = "buffer_{name}.itemsize == {size}".format(name=field.name, size=item_size) + pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name, + expected=item_size) + + if field.has_fixed_shape: + shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i) + for i, s in enumerate(field.spatial_shape)] + shape_cond = " && ".join(shape_cond) + pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name, + expected=str(field.shape)) + + expected_strides = [e * item_size for e in field.spatial_strides] + stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)" + strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name) + for i, s in enumerate(expected_strides)]) + pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name, + expected=str(expected_strides)) + else: + if FieldType.is_generic(field): + variable_sized_normal_fields.add(field) + elif FieldType.is_indexed(field): + variable_sized_index_fields.add(field) elif param.is_field_stride: field = param.fields[0] item_size = field.dtype.numpy_dtype.itemsize