From c6eedfcda96e84e8279ee624ba3c113f2339bfbe Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Wed, 22 Jan 2025 15:16:32 +0100 Subject: [PATCH] Split reduction var property into local and pointer-based reduction var properties --- .../backend/kernelcreation/context.py | 54 ++++++++++++++----- .../backend/kernelcreation/freeze.py | 31 ++++++----- .../backend/transformations/add_pragmas.py | 4 +- src/pystencils/codegen/driver.py | 10 ++-- src/pystencils/codegen/properties.py | 16 ++++-- 5 files changed, 78 insertions(+), 37 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index a8728e6ac..2f46a7421 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -9,7 +9,7 @@ from ...defaults import DEFAULTS from ...field import Field, FieldType from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType -from ...codegen.properties import ReductionSymbolProperty +from ...codegen.properties import LocalReductionVariable, ReductionPointerVariable from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant @@ -77,7 +77,8 @@ class KernelCreationContext: self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) - self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict() + self._local_reduction_symbols: dict[PsSymbol, LocalReductionVariable] = dict() + self._reduction_ptr_symbols: dict[PsSymbol, ReductionPointerVariable] = dict() self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -172,21 +173,41 @@ class KernelCreationContext: self._symbols[old.name] = new - def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty): - """Adds a reduction property to a symbol. + def add_local_reduction_symbol(self, local_symb: PsSymbol, local_var_prop: LocalReductionVariable): + """Adds entry for a symbol and its property to the lookup table for local reduction variables. - The symbol ``symbol`` should not have a reduction property and must exist in the symbol table. + The symbol ``symbol`` should not have a 'LocalReductionSymbol' property and shall not exist in the symbol table. """ - if self.find_symbol(symbol.name) is None: + if self.find_symbol(local_symb.name) is not None: raise PsInternalCompilerError( - f"add_reduction_to_symbol: {symbol.name} does not exist in the symbol table" + f"add_local_reduction_symbol: {local_symb.name} already exist in the symbol table" ) + self.add_symbol(local_symb) - if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty): - symbol.add_property(reduction) - self._symbols_with_reduction[symbol] = reduction + if local_symb not in self._local_reduction_symbols and not local_symb.get_properties(LocalReductionVariable): + local_symb.add_property(local_var_prop) + self._local_reduction_symbols[local_symb] = local_var_prop else: - raise PsInternalCompilerError(f"add_reduction_to_symbol: {symbol.name} already has a reduction property") + raise PsInternalCompilerError( + f"add_local_reduction_symbol: {local_symb.name} already exists in local reduction table" + ) + + def add_reduction_ptr_symbol(self, orig_symb: PsSymbol, ptr_symb: PsSymbol, ptr_var_prop: ReductionPointerVariable): + """Replaces reduction symbol with a pointer-based counterpart used for export + and adds the new symbol and its property to the lookup table for pointer-based reduction variables + + The symbol ``ptr_symbol`` should not exist in the symbol table. + """ + self.replace_symbol(orig_symb, ptr_symb) + + if ptr_symb not in self._reduction_ptr_symbols and not ptr_symb.get_properties( + ReductionPointerVariable): + ptr_symb.add_property(ptr_var_prop) + self._reduction_ptr_symbols[ptr_symb] = ptr_var_prop + else: + raise PsInternalCompilerError( + f"add_reduction_ptr_symbol: {ptr_symb.name} already exists in pointer-based reduction variable table " + ) def duplicate_symbol( self, symb: PsSymbol, new_dtype: PsType | None = None @@ -224,9 +245,14 @@ class KernelCreationContext: return self._symbols.values() @property - def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]: - """Return a dictionary holding symbols and their reduction property.""" - return self._symbols_with_reduction + def local_reduction_symbols(self) -> dict[PsSymbol, LocalReductionVariable]: + """Return a dictionary holding kernel-local reduction symbols and their reduction properties.""" + return self._local_reduction_symbols + + @property + def reduction_pointer_symbols(self) -> dict[PsSymbol, ReductionPointerVariable]: + """Return a dictionary holding pointer-based reduction symbols and their reduction properties.""" + return self._reduction_ptr_symbols # Fields and Arrays diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index d8fb1b91e..1e9984def 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -66,7 +66,7 @@ from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions from ..exceptions import FreezeError -from ...codegen.properties import ReductionSymbolProperty +from ...codegen.properties import LocalReductionVariable, ReductionPointerVariable ExprLike = ( @@ -194,40 +194,45 @@ class FreezeExpressions: assert isinstance(rhs, PsExpression) assert isinstance(lhs, PsSymbolExpr) + orig_lhs_symb = lhs.symbol + dtype = rhs.dtype # TODO: kernel with (implicit) up/downcasts? + + # replace original symbol with pointer-based type used for export + orig_lhs_symb_as_ptr = PsSymbol(orig_lhs_symb.name, PsPointerType(dtype)) + # create kernel-local copy of lhs symbol to work with - new_lhs_symb = PsSymbol(f"{lhs.symbol.name}_local", rhs.dtype) + new_lhs_symb = PsSymbol(f"{orig_lhs_symb.name}_local", dtype) new_lhs = PsSymbolExpr(new_lhs_symb) - self._ctx.add_symbol(new_lhs_symb) # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment) new_rhs: PsExpression init_val: PsExpression match expr.op: case "+": - init_val = PsConstantExpr(PsConstant(0)) + init_val = PsConstantExpr(PsConstant(0, dtype)) new_rhs = add(new_lhs.clone(), rhs) case "-": - init_val = PsConstantExpr(PsConstant(0)) + init_val = PsConstantExpr(PsConstant(0, dtype)) new_rhs = sub(new_lhs.clone(), rhs) case "*": - init_val = PsConstantExpr(PsConstant(1)) + init_val = PsConstantExpr(PsConstant(1, dtype)) new_rhs = mul(new_lhs.clone(), rhs) case "min": init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) + init_val.dtype = dtype new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs]) case "max": init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) + init_val.dtype = dtype new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs]) case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") - # replace original symbol with pointer-based type used for export - orig_symbol_as_ptr = PsSymbol(lhs.symbol.name, PsPointerType(rhs.dtype)) - self._ctx.replace_symbol(lhs.symbol, orig_symbol_as_ptr) - - # set reduction symbol property in context - init_val.dtype = rhs.dtype - self._ctx.add_reduction_to_symbol(new_lhs_symb, ReductionSymbolProperty(expr.op, init_val, orig_symbol_as_ptr)) + # set reduction symbol properties (local/pointer variables) in context + self._ctx.add_local_reduction_symbol(new_lhs_symb, + LocalReductionVariable(expr.op, init_val, orig_lhs_symb_as_ptr)) + self._ctx.add_reduction_ptr_symbol(orig_lhs_symb, orig_lhs_symb_as_ptr, + ReductionPointerVariable(expr.op, new_lhs_symb)) return PsAssignment(new_lhs, new_rhs) diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 6d72e1550..44d1d1ede 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -112,8 +112,8 @@ class AddOpenMP: pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" pragma_text += f" for schedule({omp_params.schedule})" - if bool(ctx.symbols_with_reduction): - for symbol, reduction in ctx.symbols_with_reduction.items(): + if bool(ctx.local_reduction_symbols): + for symbol, reduction in ctx.local_reduction_symbols.items(): if isinstance(symbol.dtype, PsScalarType): pragma_text += f" reduction({reduction.op}: {symbol.name})" else: diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 20615ba21..7f90f62ce 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -154,14 +154,14 @@ class DefaultKernelCreationDriver: self._intermediates.constants_eliminated = kernel_ast.clone() # Init local reduction variable copy - for red, prop in self._ctx.symbols_with_reduction.items(): - kernel_ast.statements = [PsDeclaration(PsSymbolExpr(red), prop.init_val)] + kernel_ast.statements + for local_red, prop in self._ctx.local_reduction_symbols.items(): + kernel_ast.statements = [PsDeclaration(PsSymbolExpr(local_red), prop.init_val)] + kernel_ast.statements # Write back result to reduction target variable - for red, prop in self._ctx.symbols_with_reduction.items(): + for red_ptr, prop in self._ctx.reduction_pointer_symbols.items(): kernel_ast.statements += [PsAssignment( - PsMemAcc(PsSymbolExpr(prop.orig_symbol), PsConstantExpr(PsConstant(0, self._ctx.index_dtype))), - PsSymbolExpr(red))] + PsMemAcc(PsSymbolExpr(red_ptr), PsConstantExpr(PsConstant(0, self._ctx.index_dtype))), + PsSymbolExpr(prop.local_symbol))] # Target-Specific optimizations if self._cfg.target.is_cpu(): diff --git a/src/pystencils/codegen/properties.py b/src/pystencils/codegen/properties.py index 4b8e7f2bf..1e71c5b98 100644 --- a/src/pystencils/codegen/properties.py +++ b/src/pystencils/codegen/properties.py @@ -15,15 +15,25 @@ class UniqueSymbolProperty(PsSymbolProperty): @dataclass(frozen=True) -class ReductionSymbolProperty(UniqueSymbolProperty): - """Property for symbols specifying the operation and initial value for a reduction.""" +class LocalReductionVariable(PsSymbolProperty): + """Property for symbols specifying the operation and initial value for a kernel-local reduction variable.""" from ..backend.memory import PsSymbol from ..backend.ast.expressions import PsExpression op: str init_val: PsExpression - orig_symbol: PsSymbol + ptr_symbol: PsSymbol + + +@dataclass(frozen=True) +class ReductionPointerVariable(PsSymbolProperty): + """Property for pointer-type symbols exporting the reduction result from the kernel.""" + + from ..backend.memory import PsSymbol + + op: str + local_symbol: PsSymbol @dataclass(frozen=True) -- GitLab