diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 06d98a44e619124d3fcbd4ee226e28e6924b0152..d8fb1b91ef03372c66725ddd291aa710398da005 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -61,7 +61,7 @@ from ..ast.expressions import ( from ..ast.vector import PsVecMemAcc from ..constants import PsConstant -from ...types import PsNumericType, PsStructType, PsType +from ...types import PsNumericType, PsStructType, PsType, PsPointerType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions from ..exceptions import FreezeError @@ -195,9 +195,9 @@ class FreezeExpressions: assert isinstance(lhs, PsSymbolExpr) # create kernel-local copy of lhs symbol to work with - new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype) - new_lhs = PsSymbolExpr(new_lhs_symbol) - self._ctx.add_symbol(new_lhs_symbol) + new_lhs_symb = PsSymbol(f"{lhs.symbol.name}_local", rhs.dtype) + new_lhs = PsSymbolExpr(new_lhs_symb) + self._ctx.add_symbol(new_lhs_symb) # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment) new_rhs: PsExpression @@ -221,8 +221,13 @@ class FreezeExpressions: case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") + # replace original symbol with pointer-based type used for export + orig_symbol_as_ptr = PsSymbol(lhs.symbol.name, PsPointerType(rhs.dtype)) + self._ctx.replace_symbol(lhs.symbol, orig_symbol_as_ptr) + # set reduction symbol property in context - self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol)) + init_val.dtype = rhs.dtype + self._ctx.add_reduction_to_symbol(new_lhs_symb, ReductionSymbolProperty(expr.op, init_val, orig_symbol_as_ptr)) return PsAssignment(new_lhs, new_rhs) diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 4b08b84ef4bea93550e475c6ab1fa3fb0314d441..04d7376d058bcc5d64ee7400b56189b57f294804 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -7,7 +7,7 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr from .parameters import Parameter -from ..backend.ast.expressions import PsSymbolExpr +from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr from ..types import create_numeric_type, PsIntegerType, PsScalarType @@ -158,11 +158,9 @@ class DefaultKernelCreationDriver: # Write back result to reduction target variable for red, prop in self._ctx.symbols_with_reduction.items(): - kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))] - - # TODO: can this be omitted? - typify = Typifier(self._ctx) - kernel_ast = typify(kernel_ast) + kernel_ast.statements += [PsAssignment( + PsMemAcc(PsSymbolExpr(prop.orig_symbol), PsConstantExpr(PsConstant(0, self._ctx.index_dtype))), + PsSymbolExpr(red))] # Target-Specific optimizations if self._cfg.target.is_cpu():