From 3e0daa67359c7ddc17264b7fd21aa0a0429552e5 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Wed, 22 Jan 2025 16:01:00 +0100 Subject: [PATCH] Propagate properties of reduction pointer symbols to kernel parameters --- src/pystencils/codegen/driver.py | 5 +++-- src/pystencils/jit/cpu_extension_module.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 7f90f62ce..f414b953e 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, replace from .target import Target from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange -from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr +from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable from .parameters import Parameter from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr @@ -461,7 +461,8 @@ def _get_function_params( props: set[PsSymbolProperty] = set() for prop in symb.properties: match prop: - # TODO: how to export reduction result (via pointer)? + case ReductionPointerVariable(): + props.add(prop) case FieldShape() | FieldStride(): props.add(prop) case BufferBasePtr(buf): diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index befb033e6..c2c969eaa 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -13,7 +13,7 @@ from ..codegen import ( Kernel, Parameter, ) -from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride +from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride, ReductionPointerVariable from ..types import ( PsType, PsUnsignedIntegerType, @@ -265,7 +265,10 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return self._array_buffers[field] def extract_scalar(self, param: Parameter) -> str: - if param not in self._scalar_extractions: + if any(isinstance(e, ReductionPointerVariable) for e in param.properties): + # TODO: implement + pass + elif param not in self._scalar_extractions: extract_func = self._scalar_extractor(param.dtype) code = self.TMPL_EXTRACT_SCALAR.format( name=param.name, -- GitLab