From 90ca9ead0199cd4f5988e6c43e9c9c5350f566b6 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 20 Jan 2025 17:46:49 +0100
Subject: [PATCH] Try initializing kernel-local reduction variable copy

---
 .../backend/kernelcreation/freeze.py          | 28 +++++++++++--------
 src/pystencils/codegen/driver.py              | 12 +++++++-
 src/pystencils/codegen/properties.py          |  7 +++--
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 840329013..e0dcba8fd 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -7,6 +7,7 @@ import sympy.core.relational
 import sympy.logic.boolalg
 from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
 
+from ..memory import PsSymbol
 from ...assignment import Assignment
 from ...simp import AssignmentCollection
 from ...sympyextensions import (
@@ -193,32 +194,37 @@ class FreezeExpressions:
         assert isinstance(rhs, PsExpression)
         assert isinstance(lhs, PsSymbolExpr)
 
+        # create kernel-local copy of lhs symbol to work with
+        new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype)
+        new_lhs = PsSymbolExpr(new_lhs_symbol)
+        self._ctx.add_symbol(new_lhs_symbol)
+
         # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
         new_rhs: PsExpression
-        init_val: Any  # TODO: type?
+        init_val: PsExpression
         match expr.op:
             case "+":
-                init_val = PsConstant(0)
-                new_rhs = add(lhs.clone(), rhs)
+                init_val = PsConstantExpr(PsConstant(0))
+                new_rhs = add(new_lhs.clone(), rhs)
             case "-":
-                init_val = PsConstant(0)
-                new_rhs = sub(lhs.clone(), rhs)
+                init_val = PsConstantExpr(PsConstant(0))
+                new_rhs = sub(new_lhs.clone(), rhs)
             case "*":
-                init_val = PsConstant(1)
-                new_rhs = mul(lhs.clone(), rhs)
+                init_val = PsConstantExpr(PsConstant(1))
+                new_rhs = mul(new_lhs.clone(), rhs)
             case "min":
                 init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
-                new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs])
+                new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs])
             case "max":
                 init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
-                new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs])
+                new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs])
             case _:
                 raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
 
         # set reduction symbol property in context
-        self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val))
+        self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol))
 
-        return PsAssignment(lhs, new_rhs)
+        return PsAssignment(new_lhs, new_rhs)
 
     def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
         symb = self._ctx.get_symbol(spsym.name)
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 7bdec96cc..199860743 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -7,12 +7,13 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
 from .kernel import Kernel, GpuKernel, GpuThreadsRange
 from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
 from .parameters import Parameter
+from ..backend.ast.expressions import PsSymbolExpr
 
 from ..types import create_numeric_type, PsIntegerType, PsScalarType
 
 from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
-from ..backend.ast.structural import PsBlock, PsLoop
+from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment
 from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
 from ..backend.kernelcreation import (
     KernelCreationContext,
@@ -151,6 +152,14 @@ class DefaultKernelCreationDriver:
         if self._intermediates is not None:
             self._intermediates.constants_eliminated = kernel_ast.clone()
 
+        #   Init local reduction variable copy
+        # for red, prop in self._ctx.symbols_with_reduction.items():
+        #     kernel_ast.statements = [PsAssignment(PsSymbolExpr(red), prop.init_val)] + kernel_ast.statements
+
+        #   Write back result to reduction target variable
+        # for red, prop in self._ctx.symbols_with_reduction.items():
+        #     kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))]
+
         #   Target-Specific optimizations
         if self._cfg.target.is_cpu():
             kernel_ast = self._transform_for_cpu(kernel_ast)
@@ -449,6 +458,7 @@ def _get_function_params(
         props: set[PsSymbolProperty] = set()
         for prop in symb.properties:
             match prop:
+                # TODO: how to export reduction result (via pointer)?
                 case FieldShape() | FieldStride():
                     props.add(prop)
                 case BufferBasePtr(buf):
diff --git a/src/pystencils/codegen/properties.py b/src/pystencils/codegen/properties.py
index 0bad4e898..4b8e7f2bf 100644
--- a/src/pystencils/codegen/properties.py
+++ b/src/pystencils/codegen/properties.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 from dataclasses import dataclass
 
 from ..field import Field
-from typing import Any
 
 
 @dataclass(frozen=True)
@@ -19,8 +18,12 @@ class UniqueSymbolProperty(PsSymbolProperty):
 class ReductionSymbolProperty(UniqueSymbolProperty):
     """Property for symbols specifying the operation and initial value for a reduction."""
 
+    from ..backend.memory import PsSymbol
+    from ..backend.ast.expressions import PsExpression
+
     op: str
-    init_val: Any  # TODO: type?
+    init_val: PsExpression
+    orig_symbol: PsSymbol
 
 
 @dataclass(frozen=True)
-- 
GitLab