diff --git a/src/pystencils/codegen/parameters.py b/src/pystencils/codegen/parameters.py
index d8411266ee514d4270a7a9d1c2fb24383f005329..094553517a56a189f4fa714749b5f7b5761f8e33 100644
--- a/src/pystencils/codegen/parameters.py
+++ b/src/pystencils/codegen/parameters.py
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
 from warnings import warn
-from typing import Sequence, Iterable
+from typing import Sequence, Iterable, Optional
 
 from .properties import (
     PsSymbolProperty,
     _FieldProperty,
     FieldShape,
     FieldStride,
-    FieldBasePtr,
+    FieldBasePtr, ReductionPointerVariable,
 )
 from ..types import PsType
 from ..field import Field
@@ -39,6 +39,9 @@ class Parameter:
                 key=lambda f: f.name,
             )
         )
+        self._reduction_ptr: Optional[ReductionPointerVariable] = next(
+            (e for e in self._properties if isinstance(e, ReductionPointerVariable)), None
+        )
 
     @property
     def name(self):
@@ -79,6 +82,11 @@ class Parameter:
         """Set of fields associated with this parameter."""
         return self._fields
 
+    @property
+    def reduction_pointer(self) -> Optional[ReductionPointerVariable]:
+        """Reduction pointer associated with this parameter."""
+        return self._reduction_ptr
+
     def get_properties(
         self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
     ) -> set[PsSymbolProperty]:
@@ -105,6 +113,10 @@ class Parameter:
         )
         return bool(self.get_properties(FieldBasePtr))
 
+    @property
+    def is_reduction_pointer(self) -> bool:
+        return bool(self._reduction_ptr)
+
     @property
     def is_field_stride(self) -> bool:  # pragma: no cover
         warn(
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index c2c969eaae1057f7892150d10124f11d49e6d060..f9c04200c249594babede2bab3ff79d59e909045 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -206,6 +206,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         self._array_assoc_var_extractions: dict[Parameter, str] = dict()
         self._scalar_extractions: dict[Parameter, str] = dict()
 
+        self._reduction_ptrs: dict[Parameter, str] = dict()
+
         self._constraint_checks: list[str] = []
 
         self._call: str | None = None
@@ -265,10 +267,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         return self._array_buffers[field]
 
     def extract_scalar(self, param: Parameter) -> str:
-        if any(isinstance(e, ReductionPointerVariable) for e in param.properties):
-            # TODO: implement
-            pass
-        elif param not in self._scalar_extractions:
+        if param not in self._scalar_extractions:
             extract_func = self._scalar_extractor(param.dtype)
             code = self.TMPL_EXTRACT_SCALAR.format(
                 name=param.name,
@@ -279,6 +278,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
+    def extract_reduction_ptr(self, param: Parameter) -> str:
+        if param not in self._reduction_ptrs:
+            # TODO: implement
+            pass
+        return param.name
+
     def extract_array_assoc_var(self, param: Parameter) -> str:
         if param not in self._array_assoc_var_extractions:
             field = param.fields[0]
@@ -306,7 +311,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         return param.name
 
     def extract_parameter(self, param: Parameter):
-        if param.is_field_parameter:
+        if param.is_reduction_pointer:
+            self.extract_reduction_ptr(param)
+        elif param.is_field_parameter:
             self.extract_array_assoc_var(param)
         else:
             self.extract_scalar(param)