From a3025645e30265621d41315e7a0449768225c361 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Fri, 17 Jan 2025 14:55:42 +0100 Subject: [PATCH] Fix min/max reductions --- src/pystencils/backend/kernelcreation/freeze.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 9a34303e2..64230203f 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -193,29 +193,31 @@ class FreezeExpressions: assert isinstance(rhs, PsExpression) assert isinstance(lhs, PsSymbolExpr) + # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment) + new_rhs: PsExpression match expr.op: case "+": - op = add init_val = PsConstant(0) + new_rhs = add(lhs.clone(), rhs) case "-": - op = sub init_val = PsConstant(0) + new_rhs = sub(lhs.clone(), rhs) case "*": - op = mul init_val = PsConstant(1) - # TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards + new_rhs = mul(lhs.clone(), rhs) case "min": - op = sp.Min init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) + new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs]) case "max": - op = sp.Max init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) + new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs]) case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") + # set reduction symbol property in context self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val)) - return PsAssignment(lhs, op(lhs.clone(), rhs)) + return PsAssignment(lhs, new_rhs) def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) -- GitLab