From f1c556e6f93d5fa042e12e8a0a9c57f3bdea47b7 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Wed, 22 Jan 2025 16:30:43 +0100
Subject: [PATCH] Integrate reduction pointers to parameters.py

---
 src/pystencils/codegen/parameters.py       | 16 ++++++++++++++--
 src/pystencils/jit/cpu_extension_module.py | 17 ++++++++++++-----
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/codegen/parameters.py b/src/pystencils/codegen/parameters.py
index d8411266e..094553517 100644
--- a/src/pystencils/codegen/parameters.py
+++ b/src/pystencils/codegen/parameters.py
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
 from warnings import warn
-from typing import Sequence, Iterable
+from typing import Sequence, Iterable, Optional
 
 from .properties import (
     PsSymbolProperty,
     _FieldProperty,
     FieldShape,
     FieldStride,
-    FieldBasePtr,
+    FieldBasePtr, ReductionPointerVariable,
 )
 from ..types import PsType
 from ..field import Field
@@ -39,6 +39,9 @@ class Parameter:
                 key=lambda f: f.name,
             )
         )
+        self._reduction_ptr: Optional[ReductionPointerVariable] = next(
+            (e for e in self._properties if isinstance(e, ReductionPointerVariable)), None
+        )
 
     @property
     def name(self):
@@ -79,6 +82,11 @@ class Parameter:
         """Set of fields associated with this parameter."""
         return self._fields
 
+    @property
+    def reduction_pointer(self) -> Optional[ReductionPointerVariable]:
+        """Reduction pointer associated with this parameter."""
+        return self._reduction_ptr
+
     def get_properties(
         self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
     ) -> set[PsSymbolProperty]:
@@ -105,6 +113,10 @@ class Parameter:
         )
         return bool(self.get_properties(FieldBasePtr))
 
+    @property
+    def is_reduction_pointer(self) -> bool:
+        return bool(self._reduction_ptr)
+
     @property
     def is_field_stride(self) -> bool:  # pragma: no cover
         warn(
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index c2c969eaa..f9c04200c 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -206,6 +206,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         self._array_assoc_var_extractions: dict[Parameter, str] = dict()
         self._scalar_extractions: dict[Parameter, str] = dict()
 
+        self._reduction_ptrs: dict[Parameter, str] = dict()
+
         self._constraint_checks: list[str] = []
 
         self._call: str | None = None
@@ -265,10 +267,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         return self._array_buffers[field]
 
     def extract_scalar(self, param: Parameter) -> str:
-        if any(isinstance(e, ReductionPointerVariable) for e in param.properties):
-            # TODO: implement
-            pass
-        elif param not in self._scalar_extractions:
+        if param not in self._scalar_extractions:
             extract_func = self._scalar_extractor(param.dtype)
             code = self.TMPL_EXTRACT_SCALAR.format(
                 name=param.name,
@@ -279,6 +278,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
+    def extract_reduction_ptr(self, param: Parameter) -> str:
+        if param not in self._reduction_ptrs:
+            # TODO: implement
+            pass
+        return param.name
+
     def extract_array_assoc_var(self, param: Parameter) -> str:
         if param not in self._array_assoc_var_extractions:
             field = param.fields[0]
@@ -306,7 +311,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         return param.name
 
     def extract_parameter(self, param: Parameter):
-        if param.is_field_parameter:
+        if param.is_reduction_pointer:
+            self.extract_reduction_ptr(param)
+        elif param.is_field_parameter:
             self.extract_array_assoc_var(param)
         else:
             self.extract_scalar(param)
-- 
GitLab