diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 4b4604a2153d7be40f750140566703f6c02b6355..b9df6f6826c6575974f061e89b3b7ef0ae03cbe3 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -75,7 +75,7 @@ class KernelCreationContext: self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) - # TODO: add list of reduction symbols + self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict() self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -170,6 +170,21 @@ class KernelCreationContext: self._symbols[old.name] = new + def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty): + """Adds a reduction property to a symbol. + + The symbol ``symbol`` should not have a reduction property and must exist in the symbol table. + """ + if self.find_symbol(symbol.name) is None: + raise PsInternalCompilerError( + "add_reduction_to_symbol: Symbol does not exist in the symbol table" + ) + + if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty): + self._symbols_with_reduction[symbol] = reduction + else: + raise PsInternalCompilerError(f"add_reduction_to_symbol: Symbol {symbol.name} already has a reduction property") + def duplicate_symbol( self, symb: PsSymbol, new_dtype: PsType | None = None ) -> PsSymbol: