From 3e0daa67359c7ddc17264b7fd21aa0a0429552e5 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Wed, 22 Jan 2025 16:01:00 +0100
Subject: [PATCH] Propagate properties of reduction pointer symbols to kernel
 parameters

---
 src/pystencils/codegen/driver.py           | 5 +++--
 src/pystencils/jit/cpu_extension_module.py | 7 +++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 7f90f62ce..f414b953e 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass, replace
 from .target import Target
 from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
 from .kernel import Kernel, GpuKernel, GpuThreadsRange
-from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
+from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable
 from .parameters import Parameter
 from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
 
@@ -461,7 +461,8 @@ def _get_function_params(
         props: set[PsSymbolProperty] = set()
         for prop in symb.properties:
             match prop:
-                # TODO: how to export reduction result (via pointer)?
+                case ReductionPointerVariable():
+                    props.add(prop)
                 case FieldShape() | FieldStride():
                     props.add(prop)
                 case BufferBasePtr(buf):
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index befb033e6..c2c969eaa 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -13,7 +13,7 @@ from ..codegen import (
     Kernel,
     Parameter,
 )
-from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride
+from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride, ReductionPointerVariable
 from ..types import (
     PsType,
     PsUnsignedIntegerType,
@@ -265,7 +265,10 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         return self._array_buffers[field]
 
     def extract_scalar(self, param: Parameter) -> str:
-        if param not in self._scalar_extractions:
+        if any(isinstance(e, ReductionPointerVariable) for e in param.properties):
+            # TODO: implement
+            pass
+        elif param not in self._scalar_extractions:
             extract_func = self._scalar_extractor(param.dtype)
             code = self.TMPL_EXTRACT_SCALAR.format(
                 name=param.name,
-- 
GitLab