From c51ae2b438d688f871ba45b69476ef0c3b475462 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Tue, 21 Jan 2025 18:42:43 +0100
Subject: [PATCH] Set type of reduced variable to pointer and write back via
 PsMemAcc

---
 src/pystencils/backend/kernelcreation/freeze.py | 15 ++++++++++-----
 src/pystencils/codegen/driver.py                | 10 ++++------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 06d98a44e..d8fb1b91e 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -61,7 +61,7 @@ from ..ast.expressions import (
 from ..ast.vector import PsVecMemAcc
 
 from ..constants import PsConstant
-from ...types import PsNumericType, PsStructType, PsType
+from ...types import PsNumericType, PsStructType, PsType, PsPointerType
 from ..exceptions import PsInputError
 from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions
 from ..exceptions import FreezeError
@@ -195,9 +195,9 @@ class FreezeExpressions:
         assert isinstance(lhs, PsSymbolExpr)
 
         # create kernel-local copy of lhs symbol to work with
-        new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype)
-        new_lhs = PsSymbolExpr(new_lhs_symbol)
-        self._ctx.add_symbol(new_lhs_symbol)
+        new_lhs_symb = PsSymbol(f"{lhs.symbol.name}_local", rhs.dtype)
+        new_lhs = PsSymbolExpr(new_lhs_symb)
+        self._ctx.add_symbol(new_lhs_symb)
 
         # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
         new_rhs: PsExpression
@@ -221,8 +221,13 @@ class FreezeExpressions:
             case _:
                 raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
 
+        # replace the original symbol with a same-named, pointer-typed symbol used for export
+        orig_symbol_as_ptr = PsSymbol(lhs.symbol.name, PsPointerType(rhs.dtype))
+        self._ctx.replace_symbol(lhs.symbol, orig_symbol_as_ptr)
+
         # set reduction symbol property in context
-        self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol))
+        init_val.dtype = rhs.dtype
+        self._ctx.add_reduction_to_symbol(new_lhs_symb, ReductionSymbolProperty(expr.op, init_val, orig_symbol_as_ptr))
 
         return PsAssignment(new_lhs, new_rhs)
 
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 06a5fd44a..20615ba21 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -7,7 +7,7 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
 from .kernel import Kernel, GpuKernel, GpuThreadsRange
 from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
 from .parameters import Parameter
-from ..backend.ast.expressions import PsSymbolExpr
+from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
 
 from ..types import create_numeric_type, PsIntegerType, PsScalarType
 
@@ -159,11 +159,9 @@ class DefaultKernelCreationDriver:
 
         #   Write back result to reduction target variable
         for red, prop in self._ctx.symbols_with_reduction.items():
-            kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))]
-
-        # TODO: can this be omitted?
-        typify = Typifier(self._ctx)
-        kernel_ast = typify(kernel_ast)
+            kernel_ast.statements += [PsAssignment(
+                PsMemAcc(PsSymbolExpr(prop.orig_symbol), PsConstantExpr(PsConstant(0, self._ctx.index_dtype))),
+                PsSymbolExpr(red))]
 
         #   Target-Specific optimizations
         if self._cfg.target.is_cpu():
-- 
GitLab