From a3025645e30265621d41315e7a0449768225c361 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Fri, 17 Jan 2025 14:55:42 +0100
Subject: [PATCH] Fix min/max reductions

---
 src/pystencils/backend/kernelcreation/freeze.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 9a34303e2..64230203f 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -193,29 +193,31 @@ class FreezeExpressions:
         assert isinstance(rhs, PsExpression)
         assert isinstance(lhs, PsSymbolExpr)
 
+        # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
+        new_rhs: PsExpression
         match expr.op:
             case "+":
-                op = add
                 init_val = PsConstant(0)
+                new_rhs = add(lhs.clone(), rhs)
             case "-":
-                op = sub
                 init_val = PsConstant(0)
+                new_rhs = sub(lhs.clone(), rhs)
             case "*":
-                op = mul
                 init_val = PsConstant(1)
-            # TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards
+                new_rhs = mul(lhs.clone(), rhs)
             case "min":
-                op = sp.Min
                 init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
+                new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs])
             case "max":
-                op = sp.Max
                 init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
+                new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs])
             case _:
                 raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
 
+        # set reduction symbol property in context
         self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val))
 
-        return PsAssignment(lhs, op(lhs.clone(), rhs))
+        return PsAssignment(lhs, new_rhs)
 
     def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
         symb = self._ctx.get_symbol(spsym.name)
-- 
GitLab