diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index f9c04200c249594babede2bab3ff79d59e909045..d8d90c924ca044ec972541ea29002b17a5b36577 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -199,9 +199,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
     """
 
     def __init__(self) -> None:
-        self._array_buffers: dict[Field, str] = dict()
-        self._array_extractions: dict[Field, str] = dict()
-        self._array_frees: dict[Field, str] = dict()
+        self._array_buffers: dict[Any, str] = dict()
+        self._array_extractions: dict[Any, str] = dict()
+        self._array_frees: dict[Any, str] = dict()
 
         self._array_assoc_var_extractions: dict[Parameter, str] = dict()
         self._scalar_extractions: dict[Parameter, str] = dict()
@@ -235,36 +235,44 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         else:
             return None
 
-    def extract_field(self, field: Field) -> str:
-        """Adds an array, and returns the name of the underlying Py_Buffer."""
-        if field not in self._array_extractions:
-            extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
+    def extract_buffer(self, buffer: Any, name: str, dtype: PsType) -> str:
+        """Register a buffer argument and return the name of its Py_Buffer.
+
+        Args:
+            buffer: Hashable key identifying the buffer; a repeated call with
+                the same key reuses the code generated by the first call.
+            name: Name used to build the ``buffer_<name>`` C variable.
+            dtype: Element type used for the format-character and itemsize checks.
+        """
+        if buffer not in self._array_extractions:
+            extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=name)
 
             # Check array type
-            type_char = self._type_char(field.dtype)
+            type_char = self._type_char(dtype)
             if type_char is not None:
-                dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
+                dtype_cond = f"buffer_{name}.format[0] == '{type_char}'"
                 extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
                     cond=dtype_cond,
                     what="data type",
-                    name=field.name,
-                    expected=str(field.dtype),
+                    name=name,
+                    expected=str(dtype),
                 )
 
             # Check item size
-            itemsize = field.dtype.itemsize
-            item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
-            extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
-                cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
-            )
+            itemsize = dtype.itemsize
+            if itemsize is not None:  # itemsize of pointer not known (TODO?)
+                item_size_cond = f"buffer_{name}.itemsize == {itemsize}"
+                extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
+                    cond=item_size_cond, what="itemsize", name=name, expected=itemsize
+                )
 
-            self._array_buffers[field] = f"buffer_{field.name}"
-            self._array_extractions[field] = extraction_code
+            self._array_buffers[buffer] = f"buffer_{name}"
+            self._array_extractions[buffer] = extraction_code
 
-            release_code = f"PyBuffer_Release(&buffer_{field.name});"
-            self._array_frees[field] = release_code
+            release_code = f"PyBuffer_Release(&buffer_{name});"
+            self._array_frees[buffer] = release_code
 
-        return self._array_buffers[field]
+        return self._array_buffers[buffer]
 
     def extract_scalar(self, param: Parameter) -> str:
         if param not in self._scalar_extractions:
@@ -280,14 +288,23 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
     def extract_reduction_ptr(self, param: Parameter) -> str:
-        if param not in self._reduction_ptrs:
-            # TODO: implement
-            pass
+        """Emit the pointer extraction for a reduction parameter and return its name.
+
+        The pointer is read from the ``buf`` member of the Py_Buffer registered
+        for ``param.reduction_pointer``.
+        """
+        if param not in self._array_assoc_var_extractions:
+            ptr = param.reduction_pointer
+            buffer = self.extract_buffer(ptr, param.name, param.dtype)
+            c_type = param.dtype.c_string()
+            code = f"{c_type} {param.name} = ({c_type}) {buffer}.buf;"
+            self._array_assoc_var_extractions[param] = code
+
         return param.name
 
     def extract_array_assoc_var(self, param: Parameter) -> str:
         if param not in self._array_assoc_var_extractions:
             field = param.fields[0]
-            buffer = self.extract_field(field)
+            buffer = self.extract_buffer(field, field.name, field.dtype)
             code: str | None = None
 
             for prop in param.properties: