diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 7f90f62cef21da224622cb2c1583d46bf975335c..f414b953e85b0d57069839808074ac417f144dc6 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, replace from .target import Target from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange -from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr +from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable from .parameters import Parameter from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr @@ -461,7 +461,8 @@ def _get_function_params( props: set[PsSymbolProperty] = set() for prop in symb.properties: match prop: - # TODO: how to export reduction result (via pointer)? + case ReductionPointerVariable(): + props.add(prop) case FieldShape() | FieldStride(): props.add(prop) case BufferBasePtr(buf): diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index befb033e6f7969a5ffd9bc7742e9e7ab691da47d..c2c969eaae1057f7892150d10124f11d49e6d060 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -13,7 +13,7 @@ from ..codegen import ( Kernel, Parameter, ) -from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride +from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride, ReductionPointerVariable from ..types import ( PsType, PsUnsignedIntegerType, @@ -265,7 +265,10 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return self._array_buffers[field] def extract_scalar(self, param: Parameter) -> str: - if param not in self._scalar_extractions: + if any(isinstance(e, ReductionPointerVariable) for e in param.properties): + # TODO: implement + pass + elif param not in self._scalar_extractions: extract_func = self._scalar_extractor(param.dtype) code = self.TMPL_EXTRACT_SCALAR.format( name=param.name,