diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 0293cce48273836b8179ec3cc98476bd3bec2e80..06a5fd44a62a7cd651d00e06eafa15edef6ec9a4 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -13,7 +13,7 @@ from ..types import create_numeric_type, PsIntegerType, PsScalarType from ..backend.memory import PsSymbol from ..backend.ast import PsAstNode -from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment +from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment, PsDeclaration from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers from ..backend.kernelcreation import ( KernelCreationContext, @@ -154,12 +154,16 @@ class DefaultKernelCreationDriver: self._intermediates.constants_eliminated = kernel_ast.clone() # Init local reduction variable copy - # for red, prop in self._ctx.symbols_with_reduction.items(): - # kernel_ast.statements = [PsAssignment(PsSymbolExpr(red), prop.init_val)] + kernel_ast.statements + for red, prop in self._ctx.symbols_with_reduction.items(): + kernel_ast.statements = [PsDeclaration(PsSymbolExpr(red), prop.init_val)] + kernel_ast.statements # Write back result to reduction target variable - # for red, prop in self._ctx.symbols_with_reduction.items(): - # kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))] + for red, prop in self._ctx.symbols_with_reduction.items(): + kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))] + + # TODO: can this be omitted? + typify = Typifier(self._ctx) + kernel_ast = typify(kernel_ast) # Target-Specific optimizations if self._cfg.target.is_cpu():