From fe2537fa4b2cba6fc13f066d00b260032f5cb739 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 3 Apr 2024 17:11:19 +0200 Subject: [PATCH] Symbol Canonicalization, Loop-Invariant Code Motion, and AST Factory --- docs/source/backend/index.rst | 2 +- docs/source/backend/translation.rst | 3 + src/pystencils/backend/ast/analysis.py | 10 +- src/pystencils/backend/ast/expressions.py | 4 +- src/pystencils/backend/ast/structural.py | 11 +- src/pystencils/backend/constants.py | 14 +- src/pystencils/backend/emission.py | 2 +- .../backend/kernelcreation/__init__.py | 2 + .../backend/kernelcreation/ast_factory.py | 187 +++++++++++++++++ .../backend/kernelcreation/context.py | 8 + .../kernelcreation/cpu_optimization.py | 17 +- .../backend/kernelcreation/freeze.py | 28 ++- .../backend/kernelcreation/iteration_space.py | 74 +++---- .../backend/kernelcreation/typification.py | 8 +- .../backend/platforms/generic_cpu.py | 22 +- .../backend/platforms/generic_gpu.py | 6 +- .../backend/transformations/__init__.py | 4 + .../transformations/canonicalize_symbols.py | 125 +++++++++++ .../transformations/eliminate_constants.py | 4 +- .../hoist_loop_invariant_decls.py | 181 ++++++++++++++++ src/pystencils/kernelcreation.py | 19 +- .../kernelcreation/platform/test_basic_cpu.py | 2 +- .../kernelcreation/platform/test_basic_gpu.py | 2 +- .../kernelcreation/test_iteration_space.py | 10 +- .../test_canonicalize_symbols.py | 88 ++++++++ .../transformations/test_hoist_invariants.py | 195 ++++++++++++++++++ 26 files changed, 923 insertions(+), 105 deletions(-) create mode 100644 src/pystencils/backend/kernelcreation/ast_factory.py create mode 100644 src/pystencils/backend/transformations/canonicalize_symbols.py create mode 100644 src/pystencils/backend/transformations/hoist_loop_invariant_decls.py create mode 100644 tests/nbackend/transformations/test_canonicalize_symbols.py create mode 100644 tests/nbackend/transformations/test_hoist_invariants.py diff --git 
a/docs/source/backend/index.rst b/docs/source/backend/index.rst index e9ac5237b..1e3968bc0 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -21,7 +21,7 @@ Internal Representation The code generator translates the kernel from the SymPy frontend's symbolic language to an internal representation (IR), which is then emitted as code in the required dialect of C. -All names of classes associated with the internal kernel representation are prefixed `Ps...` +All names of classes associated with the internal kernel representation are prefixed ``Ps...`` to distinguis them from identically named front-end and SymPy classes. The IR comprises *symbols*, *constants*, *arrays*, the *iteration space* and the *abstract syntax tree*: diff --git a/docs/source/backend/translation.rst b/docs/source/backend/translation.rst index 157675980..a4c7d36b5 100644 --- a/docs/source/backend/translation.rst +++ b/docs/source/backend/translation.rst @@ -5,6 +5,9 @@ Kernel Translation .. autoclass:: pystencils.backend.kernelcreation.KernelCreationContext :members: +.. autoclass:: pystencils.backend.kernelcreation.AstFactory + :members: + .. autoclass:: pystencils.backend.kernelcreation.KernelAnalysis :members: diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index a6ef04ebd..0ea13c563 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -18,14 +18,14 @@ from ..exceptions import PsInternalCompilerError class UndefinedSymbolsCollector: - """Collector for undefined variables. + """Collect undefined symbols. - This class implements an AST visitor that collects all `PsTypedVariable`s that have been used + This class implements an AST visitor that collects all symbols that have been used in the AST without being defined prior to their usage. 
""" def __call__(self, node: PsAstNode) -> set[PsSymbol]: - """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage.""" + """Returns all symbols that occur in the given AST without being defined prior to their usage.""" return self.visit(node) def visit(self, node: PsAstNode) -> set[PsSymbol]: @@ -79,8 +79,8 @@ class UndefinedSymbolsCollector: """Returns the set of variables declared by the given node which are visible in the enclosing scope.""" match node: - case PsDeclaration(lhs, _): - return {lhs.symbol} + case PsDeclaration(): + return {node.declared_symbol} case ( PsAssignment() diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 8a66457a9..7c743a399 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -21,7 +21,7 @@ from .astnode import PsAstNode, PsLeafMixIn class PsExpression(PsAstNode, ABC): """Base class for all expressions. - + **Types:** Each expression should be annotated with its type. 
Upon construction, the `dtype` property of most expression nodes is unset; only constant expressions, symbol expressions, and array accesses immediately inherit their type from @@ -271,7 +271,7 @@ class PsVectorArrayAccess(PsArrayAccess): @property def alignment(self) -> int: return self._alignment - + def get_vector_type(self) -> PsVectorType: return cast(PsVectorType, self._dtype) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 441faa606..47342cfed 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -4,6 +4,7 @@ from types import NoneType from .astnode import PsAstNode, PsLeafMixIn from .expressions import PsExpression, PsLvalue, PsSymbolExpr +from ..symbols import PsSymbol from .util import failing_cast @@ -121,7 +122,7 @@ class PsAssignment(PsAstNode): class PsDeclaration(PsAssignment): __match_args__ = ( - "declared_variable", + "lhs", "rhs", ) @@ -137,12 +138,8 @@ class PsDeclaration(PsAssignment): self._lhs = failing_cast(PsSymbolExpr, lvalue) @property - def declared_variable(self) -> PsSymbolExpr: - return cast(PsSymbolExpr, self._lhs) - - @declared_variable.setter - def declared_variable(self, lvalue: PsSymbolExpr): - self._lhs = lvalue + def declared_symbol(self) -> PsSymbol: + return cast(PsSymbolExpr, self._lhs).symbol def clone(self) -> PsDeclaration: return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py index 125c1149b..b867d89d3 100644 --- a/src/pystencils/backend/constants.py +++ b/src/pystencils/backend/constants.py @@ -7,10 +7,10 @@ from .exceptions import PsInternalCompilerError class PsConstant: """Type-safe representation of typed numerical constants. - + This class models constants in the backend representation of kernels. A constant may be *untyped*, in which case its ``value`` may be any Python object. 
- + If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used to check the validity of its ``value`` and to convert it into the type's internal representation. @@ -36,19 +36,19 @@ class PsConstant: def interpret_as(self, dtype: PsNumericType) -> PsConstant: """Interprets this *untyped* constant with the given data type. - + If this constant is already typed, raises an error. """ if self._dtype is not None: raise PsInternalCompilerError( f"Cannot interpret already typed constant {self} with type {dtype}" ) - + return PsConstant(self._value, dtype) - + def reinterpret_as(self, dtype: PsNumericType) -> PsConstant: """Reinterprets this constant with the given data type. - + Other than `interpret_as`, this method also works on typed constants. """ return PsConstant(self._value, dtype) @@ -60,7 +60,7 @@ class PsConstant: @property def dtype(self) -> PsNumericType | None: """This constant's data type, or ``None`` if it is untyped. - + The data type of a constant always has ``const == True``. 
""" return self._dtype diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index b742c598d..aa5f853a7 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -169,7 +169,7 @@ class CAstPrinter: return pc.indent(f"{self.visit(expr, pc)};") case PsDeclaration(lhs, rhs): - lhs_symb = lhs.symbol + lhs_symb = node.declared_symbol lhs_code = self._symbol_decl(lhs_symb) rhs_code = self.visit(rhs, pc) diff --git a/src/pystencils/backend/kernelcreation/__init__.py b/src/pystencils/backend/kernelcreation/__init__.py index 1cbddab4f..5de83caad 100644 --- a/src/pystencils/backend/kernelcreation/__init__.py +++ b/src/pystencils/backend/kernelcreation/__init__.py @@ -2,6 +2,7 @@ from .context import KernelCreationContext from .analysis import KernelAnalysis from .freeze import FreezeExpressions from .typification import Typifier +from .ast_factory import AstFactory from .iteration_space import ( FullIterationSpace, @@ -17,6 +18,7 @@ __all__ = [ "KernelAnalysis", "FreezeExpressions", "Typifier", + "AstFactory", "FullIterationSpace", "SparseIterationSpace", "create_full_iteration_space", diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py new file mode 100644 index 000000000..b9bbe8cce --- /dev/null +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -0,0 +1,187 @@ +from typing import Any, Sequence, cast, overload + +import sympy as sp +from sympy.codegen.ast import AssignmentBase + +from ..ast import PsAstNode +from ..ast.expressions import PsExpression, PsSymbolExpr +from ..ast.structural import PsLoop, PsBlock, PsAssignment + +from ..symbols import PsSymbol +from ..constants import PsConstant + +from .context import KernelCreationContext +from .freeze import FreezeExpressions +from .typification import Typifier +from .iteration_space import FullIterationSpace + + +class AstFactory: + """Factory providing a convenient interface for 
building syntax trees. + + The `AstFactory` uses the defaults provided by the given `KernelCreationContext` to quickly create + AST nodes. Depending on context (numerical, loop indexing, etc.), symbols and constants receive either + ``ctx.default_dtype`` or ``ctx.index_dtype``. + + Args: + ctx: The kernel creation context + """ + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + self._freeze = FreezeExpressions(ctx) + self._typify = Typifier(ctx) + + @overload + def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: + pass + + @overload + def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment: + pass + + def parse_sympy(self, sp_obj: sp.Expr | AssignmentBase) -> PsAstNode: + """Parse a SymPy expression or assignment through `FreezeExpressions` and `Typifier`. + + The expression or assignment will be typified in a numerical context, using the kernel + creation context's `default_dtype`. + + Args: + sp_obj: A SymPy expression or assignment + """ + return self._typify(self._freeze(sp_obj)) + + def parse_slice( + self, slic: slice, upper_limit: Any | None = None + ) -> tuple[PsExpression, PsExpression, PsExpression]: + """Parse a slice to obtain start, stop and step expressions for a loop or iteration space dimension. + + The slice entries may be instances of `PsExpression`, `PsSymbol` or `PsConstant`, in which case they + must typify with the kernel creation context's ``index_dtype``. + They may also be sympy expressions or integer constants, in which case they are parsed to AST objects + and must also typify with the kernel creation context's ``index_dtype``. + + If the slice's ``stop`` member is `None` or a negative `int`, `upper_limit` must be specified, which is then + used as the upper iteration limit as either ``upper_limit`` or ``upper_limit + stop``. 
+ + Args: + slic: The iteration slice + upper_limit: Optionally, the upper iteration limit + """ + + if slic.stop is None or (isinstance(slic.stop, int) and slic.stop < 0): + if upper_limit is None: + raise ValueError( + "Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`" + ) + + def make_expr(val: Any) -> PsExpression: + match val: + case PsExpression(): + return self._typify.typify_expression(val, self._ctx.index_dtype)[0] + case PsSymbol() | PsConstant(): + return self._typify.typify_expression( + PsExpression.make(val), self._ctx.index_dtype + )[0] + case sp.Expr(): + return self._typify.typify_expression( + self._freeze(val), self._ctx.index_dtype + )[0] + case _: + return PsExpression.make(PsConstant(val, self._ctx.index_dtype)) + + start = make_expr(slic.start if slic.start is not None else 0) + stop = make_expr(slic.stop) if slic.stop is not None else make_expr(upper_limit) + step = make_expr(slic.step if slic.step is not None else 1) + + if isinstance(slic.stop, int) and slic.stop < 0: + stop = make_expr(upper_limit) + stop + + return start, stop, step + + def loop(self, ctr_name: str, iteration_slice: slice, body: PsBlock): + """Create a loop from a slice. + + Args: + ctr_name: Name of the loop counter + iteration_slice: The iteration region as a slice; see `parse_slice`. + body: The loop body + """ + ctr = PsExpression.make(self._ctx.get_symbol(ctr_name, self._ctx.index_dtype)) + + start, stop, step = self.parse_slice(iteration_slice) + + return PsLoop( + ctr, + start, + stop, + step, + body, + ) + + def loop_nest(self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock) -> PsLoop: + """Create a loop nest from a sequence of slices. 
+ + **Example:** + This snippet creates a 3D loop nest with ten iterations in each dimension:: + + >>> from pystencils import make_slice + >>> ctx = KernelCreationContext() + >>> factory = AstFactory(ctx) + >>> loop = factory.loop_nest(("i", "j", "k"), make_slice[:10,:10,:10], PsBlock([])) + + Args: + counters: Sequence of names for the loop counters + slices: Sequence of iteration slices; see also `parse_slice` + body: The loop body + """ + if not slices: + raise ValueError( + "At least one slice must be specified to create a loop nest." + ) + + ast = body + for ctr_name, sl in zip(counters[::-1], slices[::-1], strict=True): + ast = self.loop( + ctr_name, + sl, + PsBlock([ast]) if not isinstance(ast, PsBlock) else ast, + ) + + return cast(PsLoop, ast) + + def loops_from_ispace( + self, + ispace: FullIterationSpace, + body: PsBlock, + loop_order: Sequence[int] | None = None, + ) -> PsLoop: + """Create a loop nest from a dense iteration space. + + Args: + ispace: The iteration space object + body: The loop body + loop_order: Optionally, a permutation of integers indicating the order of loops + """ + dimensions = ispace.dimensions + + if loop_order is not None: + dimensions = [dimensions[coordinate] for coordinate in loop_order] + + outer_node: PsLoop | PsBlock = body + + for dimension in dimensions[::-1]: + outer_node = PsLoop( + PsSymbolExpr(dimension.counter), + dimension.start, + dimension.stop, + dimension.step, + ( + outer_node + if isinstance(outer_node, PsBlock) + else PsBlock([outer_node]) + ), + ) + + assert isinstance(outer_node, PsLoop) + return outer_node diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 5ce373797..d48953a5b 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -106,6 +106,14 @@ class KernelCreationContext: return symb + def find_symbol(self, name: str) -> PsSymbol | None: + """Find a symbol with the 
given name in the symbol table, if it exists. + + Returns: + The symbol with the given name, or `None` if no such symbol exists. + """ + return self._symbols.get(name, None) + def add_symbol(self, symbol: PsSymbol): if symbol.name in self._symbols: raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}") diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py index 47db57823..b0156c7e8 100644 --- a/src/pystencils/backend/kernelcreation/cpu_optimization.py +++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py @@ -1,7 +1,9 @@ from __future__ import annotations +from typing import cast from .context import KernelCreationContext from ..platforms import GenericCpu +from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations from ..ast.structural import PsBlock from ...config import CpuOptimConfig @@ -11,10 +13,19 @@ def optimize_cpu( ctx: KernelCreationContext, platform: GenericCpu, kernel_ast: PsBlock, - cfg: CpuOptimConfig, -): + cfg: CpuOptimConfig | None, +) -> PsBlock: """Carry out CPU-specific optimizations according to the given configuration.""" + canonicalize = CanonicalizeSymbols(ctx, True) + kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) + + hoist_invariants = HoistLoopInvariantDeclarations(ctx) + kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast)) + + if cfg is None: + return kernel_ast + if cfg.loop_blocking: raise NotImplementedError("Loop blocking not implemented yet.") @@ -26,3 +37,5 @@ def optimize_cpu( if cfg.use_cacheline_zeroing: raise NotImplementedError("CL-zeroing not implemented yet") + + return kernel_ast diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 6ce0264e2..a9f760e97 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -1,8 +1,9 @@ from typing import overload, cast, Any from 
functools import reduce -from operator import add, mul, sub +from operator import add, mul, sub, truediv import sympy as sp +from sympy.codegen.ast import AssignmentBase, AugmentedAssignment from ...sympyextensions import Assignment, AssignmentCollection, integer_functions from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc @@ -68,13 +69,13 @@ class FreezeExpressions: pass @overload - def __call__(self, obj: Assignment) -> PsAssignment: + def __call__(self, obj: AssignmentBase) -> PsAssignment: pass def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode: if isinstance(obj, AssignmentCollection): return PsBlock([self.visit(asm) for asm in obj.all_assignments]) - elif isinstance(obj, Assignment): + elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) elif isinstance(obj, sp.Expr): return cast(PsExpression, self.visit(obj)) @@ -128,6 +129,27 @@ class FreezeExpressions: f"Encountered unsupported expression on assignment left-hand side: {lhs}" ) + def map_AugmentedAssignment(self, expr: AugmentedAssignment): + lhs = self.visit(expr.lhs) + rhs = self.visit(expr.rhs) + + assert isinstance(lhs, PsExpression) + assert isinstance(rhs, PsExpression) + + match expr.op: + case "+=": + op = add + case "-=": + op = sub + case "*=": + op = mul + case "/=": + op = truediv + case _: + raise FreezeError(f"Unsupported augmented assignment: {expr.op}.") + + return PsAssignment(lhs, op(lhs.clone(), rhs)) + def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) return PsSymbolExpr(symb) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 382adf7b6..5a093031c 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -5,8 +5,6 @@ from dataclasses import dataclass from functools import reduce from operator import mul -import sympy as sp 
- from ...defaults import DEFAULTS from ...sympyextensions import AssignmentCollection from ...field import Field, FieldType @@ -71,8 +69,8 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_with_ghost_layers( ctx: KernelCreationContext, - archetype_field: Field, ghost_layers: int | Sequence[int | tuple[int, int]], + archetype_field: Field, ) -> FullIterationSpace: """Create an iteration space over an archetype field with ghost layers.""" @@ -123,56 +121,52 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_from_slice( ctx: KernelCreationContext, - archetype_field: Field, iteration_slice: Sequence[slice], + archetype_field: Field | None = None, ): - archetype_array = ctx.get_array(archetype_field) - dim = archetype_field.spatial_dimensions - - if len(iteration_slice) != dim: + """Create an iteration space from a sequence of slices, optionally over an archetype field. + + Args: + ctx: The kernel creation context + iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` + archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order. + """ + dim = len(iteration_slice) + if dim == 0: raise ValueError( - f"Number of dimensions in slice ({len(iteration_slice)}) " - f" did not equal iteration space dimensionality ({dim})" + "At least one slice must be specified to create an iteration space" ) - counters = [ - ctx.get_symbol(name, ctx.index_dtype) - for name in DEFAULTS.spatial_counter_names[:dim] - ] - - from .freeze import FreezeExpressions - from .typification import Typifier - - freeze = FreezeExpressions(ctx) - typifier = Typifier(ctx) + archetype_size: tuple[PsSymbol | PsConstant | None, ...] 
+ if archetype_field is not None: + archetype_array = ctx.get_array(archetype_field) - def expr_convert(expr) -> PsExpression: - if isinstance(expr, int): - return PsConstantExpr(PsConstant(expr, ctx.index_dtype)) - elif isinstance(expr, sp.Expr): - typed_expr, _ = typifier.typify_expression( - freeze.freeze_expression(expr), ctx.index_dtype + if archetype_field.spatial_dimensions != dim: + raise ValueError( + f"Number of dimensions in slice ({len(iteration_slice)}) " + f" did not equal iteration space dimensionality ({dim})" ) - return typed_expr - else: - raise ValueError(f"Invalid entry in slice: {expr}") - def to_dim(slic: slice, size: PsSymbol | PsConstant, ctr: PsSymbol): - size_expr = PsExpression.make(size) + archetype_size = archetype_array.shape[:dim] + else: + archetype_size = (None,) * dim - start = expr_convert(slic.start if slic.start is not None else 0) - stop = expr_convert(slic.stop) if slic.stop is not None else size_expr - step = expr_convert(slic.step if slic.step is not None else 1) + counters = [ + ctx.get_symbol(name, ctx.index_dtype) + for name in DEFAULTS.spatial_counter_names[:dim] + ] - if isinstance(slic.stop, int) and slic.stop < 0: - stop = size_expr + stop # todo + from .ast_factory import AstFactory + factory = AstFactory(ctx) + def to_dim(slic: slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol): + start, stop, step = factory.parse_slice(slic, size) return FullIterationSpace.Dimension(start, stop, step, ctr) dimensions = [ to_dim(slic, size, ctr) for slic, size, ctr in zip( - iteration_slice, archetype_array.shape[:dim], counters, strict=True + iteration_slice, archetype_size, counters, strict=True ) ] @@ -399,13 +393,13 @@ def create_full_iteration_space( if ghost_layers is not None: return FullIterationSpace.create_with_ghost_layers( - ctx, archetype_field, ghost_layers + ctx, ghost_layers, archetype_field ) elif iteration_slice is not None: return FullIterationSpace.create_from_slice( - ctx, archetype_field, iteration_slice 
+ ctx, iteration_slice, archetype_field ) else: return FullIterationSpace.create_with_ghost_layers( - ctx, archetype_field, inferred_gls + ctx, inferred_gls, archetype_field ) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 190cd9e23..707b02c66 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -56,7 +56,7 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: """Typing context, with support for type inference and checking. - + Instances of this class are used to propagate and check data types across expression subtrees of the AST. Each type context has: @@ -185,7 +185,7 @@ class TypeContext: def _compatible(self, dtype: PsType): """Checks whether the given data type is compatible with the context's target type. - + If the target type is ``const``, they must be equal up to const qualification; if the target type is not ``const``, `dtype` must match it exactly. """ @@ -248,7 +248,7 @@ class Typifier: Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but not necessarily their const-qualification. - A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, + A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`, but not vice versa. """ @@ -321,7 +321,7 @@ class Typifier: def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None: """Recursive processing of expression nodes. - + This method opens, expands, and closes typing contexts according to the respective expression's typing rules. It may add or check restrictions only when opening or closing a type context. 
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index a83bad4ba..6899ac947 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -7,6 +7,7 @@ from ...types import PsType, PsIeeeFloatType from .platform import Platform from ..exceptions import MaterializationError +from ..kernelcreation import AstFactory from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, @@ -76,28 +77,17 @@ class GenericCpu(Platform): def _create_domain_loops( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: - - dimensions = ispace.dimensions + factory = AstFactory(self._ctx) # Determine loop order by permuting dimensions archetype_field = ispace.archetype_field if archetype_field is not None: loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] - - outer_block = body - - for dimension in dimensions[::-1]: - loop = PsLoop( - PsSymbolExpr(dimension.counter), - dimension.start, - dimension.stop, - dimension.step, - outer_block, - ) - outer_block = PsBlock([loop]) + else: + loop_order = None - return outer_block + loops = factory.loops_from_ispace(ispace, body, loop_order) + return PsBlock([loops]) def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace): mappings = [ diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 64c0cd3e9..27fcdfac2 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -56,8 +56,10 @@ class GenericGpu(Platform): ] return indices[:dim] - - def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction: + + def select_function( + self, math_function: PsMathFunction, dtype: PsType + ) -> CFunction: raise NotImplementedError() # Internals diff --git 
a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 8ef35e4fb..afb1e4fcd 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,10 +1,14 @@ from .eliminate_constants import EliminateConstants +from .canonicalize_symbols import CanonicalizeSymbols +from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", + "CanonicalizeSymbols", + "HoistLoopInvariantDeclarations", "EraseAnonymousStructTypes", "SelectFunctions", "MaterializeVectorIntrinsics", diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py new file mode 100644 index 000000000..6fe922f28 --- /dev/null +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -0,0 +1,125 @@ +from itertools import count + +from ..kernelcreation import KernelCreationContext +from ..symbols import PsSymbol +from ..exceptions import PsInternalCompilerError + +from ..ast import PsAstNode +from ..ast.structural import PsDeclaration, PsAssignment, PsLoop, PsConditional, PsBlock +from ..ast.expressions import PsSymbolExpr, PsExpression + +from ...types import constify + +__all__ = ["CanonicalizeSymbols"] + + +class CanonContext: + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self.encountered_symbols: set[PsSymbol] = set() + self.live_symbols_map: dict[PsSymbol, PsSymbol] = dict() + + self.updated_symbols: set[PsSymbol] = set() + + def deduplicate(self, symb: PsSymbol) -> PsSymbol: + if symb in self.live_symbols_map: + return self.live_symbols_map[symb] + elif symb not in self.encountered_symbols: + self.encountered_symbols.add(symb) + self.live_symbols_map[symb] = 
symb + return symb + else: + for i in count(): + replacement_name = f"{symb.name}__{i}" + if self._ctx.find_symbol(replacement_name) is None: + replacement = self._ctx.get_symbol(replacement_name, symb.dtype) + self.live_symbols_map[symb] = replacement + self.encountered_symbols.add(replacement) + return replacement + assert False, "unreachable code" + + def mark_as_updated(self, symb: PsSymbol): + self.updated_symbols.add(self.deduplicate(symb)) + + def is_live(self, symb: PsSymbol) -> bool: + return symb in self.live_symbols_map + + def end_lifespan(self, symb: PsSymbol): + if symb in self.live_symbols_map: + del self.live_symbols_map[symb] + + +class CanonicalizeSymbols: + """Remove duplicate symbol declarations and declare all non-updated symbols ``const``. + + The `CanonicalizeSymbols` pass will remove multiple declarations of the same symbol by + renaming all but the last occurrence, and will optionally ``const``-qualify all symbols + encountered in the AST that are never updated. + """ + + def __init__(self, ctx: KernelCreationContext, constify: bool = True) -> None: + self._ctx = ctx + self._constify = constify + self._last_result: CanonContext | None = None + + def get_last_live_symbols(self) -> set[PsSymbol]: + if self._last_result is None: + raise PsInternalCompilerError("Pass was not executed yet") + return set(self._last_result.live_symbols_map.values()) + + def __call__(self, node: PsAstNode) -> PsAstNode: + cc = CanonContext(self._ctx) + self.visit(node, cc) + + # Any symbol encountered but never updated can be marked const + if self._constify: + for symb in cc.encountered_symbols - cc.updated_symbols: + if symb.dtype is not None: + symb.dtype = constify(symb.dtype) + + # Any symbols still alive now are function params or globals + # Might use that to populate KernelFunction + self._last_result = cc + + return node + + def visit(self, node: PsAstNode, cc: CanonContext): + """Traverse the AST in reverse pre-order to collect, deduplicate, and maybe 
constify all live symbols.""" + + match node: + case PsSymbolExpr(symb): + node.symbol = cc.deduplicate(symb) + return node + + case PsExpression(): + for c in node.children: + self.visit(c, cc) + + case PsDeclaration(lhs, rhs): + decl_symb = node.declared_symbol + self.visit(lhs, cc) + self.visit(rhs, cc) + cc.end_lifespan(decl_symb) + + case PsAssignment(lhs, rhs): + self.visit(lhs, cc) + self.visit(rhs, cc) + + if isinstance(lhs, PsSymbolExpr): + cc.mark_as_updated(lhs.symbol) + + case PsLoop(ctr, _, _, _, _): + for c in node.children[::-1]: + self.visit(c, cc) + cc.mark_as_updated(ctr.symbol) + cc.end_lifespan(ctr.symbol) + + case PsConditional(cond, then, els): + if els is not None: + self.visit(els, cc) + self.visit(then, cc) + self.visit(cond, cc) + + case PsBlock(statements): + for stmt in statements[::-1]: + self.visit(stmt, cc) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 808743d69..22ad740fa 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -20,7 +20,6 @@ from ..ast.util import AstEqWrapper from ..constants import PsConstant from ..symbols import PsSymbol from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError -from ..emission import CAstPrinter __all__ = ["EliminateConstants"] @@ -32,6 +31,9 @@ class ECContext: self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() self._typifier = Typifier(ctx) + + from ..emission import CAstPrinter + self._printer = CAstPrinter(0) @property diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py new file mode 100644 index 000000000..592003815 --- /dev/null +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -0,0 +1,181 @@ +from typing import cast + +from ..kernelcreation import 
KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment +from ..ast.expressions import ( + PsExpression, + PsSymbolExpr, + PsConstantExpr, + PsCall, + PsDeref, + PsSubscript, + PsUnOp, + PsBinOp, + PsArrayInitList, +) + +from ...types import PsDereferencableType +from ..symbols import PsSymbol +from ..functions import PsMathFunction + +__all__ = ["HoistLoopInvariantDeclarations"] + + +class HoistContext: + def __init__(self) -> None: + self.hoisted_nodes: list[PsDeclaration] = [] + self.assigned_symbols: set[PsSymbol] = set() + self.invariant_symbols: set[PsSymbol] = set() + + def _is_invariant(self, expr: PsExpression) -> bool: + def args_invariant(expr): + return all( + self._is_invariant(cast(PsExpression, arg)) for arg in expr.children + ) + + match expr: + case PsSymbolExpr(symbol): + return (symbol not in self.assigned_symbols) or ( + symbol in self.invariant_symbols + ) + + case PsConstantExpr(): + return True + + case PsCall(func): + return isinstance(func, PsMathFunction) and args_invariant(expr) + + case PsSubscript(ptr, _) | PsDeref(ptr): + ptr_type = cast(PsDereferencableType, ptr.get_dtype()) + return ptr_type.base_type.const and args_invariant(expr) + + case PsUnOp() | PsBinOp() | PsArrayInitList(): + return args_invariant(expr) + + case _: + return False + + +class HoistLoopInvariantDeclarations: + """Hoist loop-invariant declarations out of the loop nest. + + This transformation moves loop-invariant symbol declarations outside of the loop + nest to prevent their repeated execution within the loops. + If this transformation results in the complete elimination of a loop body, the respective loop + is removed. + + `HoistLoopInvariantDeclarations` assumes that symbols are canonical; + in particular, each symbol may have at most one declaration. + To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`. 
+ + `HoistLoopInvariantDeclarations` assumes that all `PsMathFunction`s are pure (have no side effects), + but makes no such assumption about instances of `CFunction`. + """ + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self.visit(node) + + def visit(self, node: PsAstNode) -> PsAstNode: + """Search the outermost loop and start the hoisting cascade there.""" + match node: + case PsLoop(): + temp_block = PsBlock([node]) + temp_block = cast(PsBlock, self.visit(temp_block)) + if temp_block.statements == [node]: + return node + else: + return temp_block + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsLoop): + loop = stmt + hc = self._hoist(loop) + statements_new += hc.hoisted_nodes + if loop.body.statements: + statements_new.append(loop) + else: + self.visit(stmt) + statements_new.append(stmt) + + node.statements = statements_new + return node + + case PsConditional(_, then, els): + self.visit(then) + if els is not None: + self.visit(els) + return node + + case _: + # if the node is none of the above, end the search + return node + + # end match + + def _hoist(self, loop: PsLoop) -> HoistContext: + """Hoist invariant declarations out of the given loop.""" + hc = HoistContext() + hc.assigned_symbols.add(loop.counter.symbol) + self._prepare_hoist(loop.body, hc) + self._hoist_from_block(loop.body, hc) + return hc + + def _prepare_hoist(self, node: PsAstNode, hc: HoistContext): + """Collect all symbols assigned within a loop body, + and recursively apply loop-invariant code motion to any nested loops.""" + match node: + case PsExpression(): + return + + case PsAssignment(PsSymbolExpr(lhs_symb), _): + hc.assigned_symbols.add(lhs_symb) + + case PsAssignment(_, _): + return + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsLoop): + loop = stmt + nested_hc =
self._hoist(loop) + hc.assigned_symbols |= nested_hc.assigned_symbols + statements_new += nested_hc.hoisted_nodes + if loop.body.statements: + statements_new.append(loop) + else: + self._prepare_hoist(stmt, hc) + statements_new.append(stmt) + node.statements = statements_new + + case _: + for c in node.children: + self._prepare_hoist(c, hc) + + def _hoist_from_block(self, block: PsBlock, hc: HoistContext): + """Hoist invariant declarations from the given block, and any directly nested blocks. + + This method processes only statements of the given block, and any blocks directly nested inside it. + It does not descend into control structures like conditionals and nested loops. + """ + statements_new: list[PsAstNode] = [] + + for node in block.statements: + if isinstance(node, PsDeclaration): + if hc._is_invariant(node.rhs): + hc.hoisted_nodes.append(node) + hc.invariant_symbols.add(node.declared_symbol) + else: + statements_new.append(node) + else: + if isinstance(node, PsBlock): + self._hoist_from_block(node, hc) + statements_new.append(node) + + block.statements = statements_new diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 3f17e66bd..0f6941cf5 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -25,7 +25,11 @@ from .backend.kernelcreation.iteration_space import ( ) from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants, SelectFunctions +from .backend.transformations import ( + EliminateConstants, + EraseAnonymousStructTypes, + SelectFunctions, +) from .sympyextensions import AssignmentCollection, Assignment @@ -38,7 +42,7 @@ def create_kernel( config: CreateKernelConfig = CreateKernelConfig(), ) -> KernelFunction: """Create a kernel function from a set of assignments. 
- + Args: assignments: The kernel's sequence of assignments, expressed using SymPy config: The configuration for the kernel translator @@ -84,6 +88,7 @@ def create_kernel( match config.target: case Target.GenericCPU: from .backend.platforms import GenericCpu + platform = GenericCpu(ctx) case _: # TODO: CUDA/HIP platform @@ -96,13 +101,13 @@ def create_kernel( elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) kernel_ast = cast(PsBlock, elim_constants(kernel_ast)) - erase_anons = EraseAnonymousStructTypes(ctx) - kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) - # Target-Specific optimizations - if config.target.is_cpu() and config.cpu_optim is not None: + if config.target.is_cpu(): from .backend.kernelcreation import optimize_cpu - optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) + kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) + + erase_anons = EraseAnonymousStructTypes(ctx) + kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) select_functions = SelectFunctions(platform) kernel_ast = cast(PsBlock, select_functions(kernel_ast)) diff --git a/tests/nbackend/kernelcreation/platform/test_basic_cpu.py b/tests/nbackend/kernelcreation/platform/test_basic_cpu.py index 7bfcd4e42..79043c880 100644 --- a/tests/nbackend/kernelcreation/platform/test_basic_cpu.py +++ b/tests/nbackend/kernelcreation/platform/test_basic_cpu.py @@ -23,7 +23,7 @@ def test_loop_nest(layout): # FZYX Order archetype_field = Field.create_generic("field", spatial_dimensions=3, layout=layout) - ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0) + ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field) loop_nest = platform.materialize_iteration_space(body, ispace) diff --git a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py index df0e48fb0..e47f38e4d 100644 --- a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py +++ 
b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py @@ -23,6 +23,6 @@ def test_loop_nest(layout): # FZYX Order archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout=layout) - ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0) + ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field) condition = platform.materialize_iteration_space(body, ispace) diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 6b4145e98..1dfcfea2a 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -1,6 +1,5 @@ import pytest -from pystencils.defaults import DEFAULTS from pystencils.field import Field from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type @@ -8,6 +7,7 @@ from pystencils.backend.kernelcreation import KernelCreationContext, FullIterati from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.kernelcreation.typification import TypificationError +from pystencils.types import PsTypeError def test_slices(): @@ -17,7 +17,7 @@ def test_slices(): ctx.add_field(archetype_field) islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, -1)) - ispace = FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) archetype_arr = ctx.get_array(archetype_field) @@ -52,9 +52,9 @@ def test_invalid_slices(): ctx.add_field(archetype_field) islice = (slice(1, -1, 0.5),) - with pytest.raises(ValueError): - FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + with pytest.raises(PsTypeError): + FullIterationSpace.create_from_slice(ctx, islice, archetype_field) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) with pytest.raises(TypificationError): - 
FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + FullIterationSpace.create_from_slice(ctx, islice, archetype_field) diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py new file mode 100644 index 000000000..43f269163 --- /dev/null +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -0,0 +1,88 @@ +# type: ignore +import sympy as sp + +from pystencils import Field, Assignment, AddAugmentedAssignment, make_slice, DEFAULTS + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import CanonicalizeSymbols +from pystencils.backend.ast.structural import PsConditional, PsBlock + + +def test_deduplication(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + + f = Field.create_fixed_size("f", (5, 5), strides=(5, 1)) + x, y, z = sp.symbols("x, y, z") + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) + ctx.set_iteration_space(ispace) + + ctr_1 = DEFAULTS.spatial_counters[1] + + then_branch = PsBlock( + [ + factory.parse_sympy(Assignment(x, y)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ) + + else_branch = PsBlock( + [ + factory.parse_sympy(Assignment(x, z)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ) + + ast = PsConditional( + factory.parse_sympy(ctr_1), + then_branch, + else_branch, + ) + + ast = factory.loops_from_ispace(ispace, PsBlock([ast])) + + ast = canonicalize(ast) + + assert canonicalize.get_last_live_symbols() == { + ctx.find_symbol("y"), + ctx.find_symbol("z"), + ctx.get_array(f).base_pointer, + } + + assert ctx.find_symbol("x") is not None + assert ctx.find_symbol("x__0") is not None + + assert then_branch.statements[0].declared_symbol.name == "x__0" + assert then_branch.statements[1].rhs.symbol.name == "x__0" + + assert 
else_branch.statements[0].declared_symbol.name == "x" + assert else_branch.statements[1].rhs.symbol.name == "x" + + assert ctx.find_symbol("x").dtype.const + assert ctx.find_symbol("x__0").dtype.const + assert ctx.find_symbol("y").dtype.const + assert ctx.find_symbol("z").dtype.const + + +def test_do_not_constify(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + + x, z = sp.symbols("x, z") + + ast = factory.loop("i", make_slice[:10], PsBlock([ + factory.parse_sympy(Assignment(x, z)), + factory.parse_sympy(AddAugmentedAssignment(z, 1)) + ])) + + ast = canonicalize(ast) + + assert ctx.find_symbol("x").dtype.const + assert not ctx.find_symbol("z").dtype.const diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py new file mode 100644 index 000000000..db78efce5 --- /dev/null +++ b/tests/nbackend/transformations/test_hoist_invariants.py @@ -0,0 +1,195 @@ +import sympy as sp + +from pystencils import ( + Field, + TypedSymbol, + Assignment, + AddAugmentedAssignment, + make_slice, +) +from pystencils.types.quick import Arr, Fp, Bool + +from pystencils.backend.ast.structural import ( + PsBlock, + PsLoop, + PsConditional, + PsDeclaration, +) + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import ( + CanonicalizeSymbols, + HoistLoopInvariantDeclarations, +) + + +def test_hoist_multiple_loops(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + f = Field.create_fixed_size("f", (5, 5), strides=(5, 1)) + x, y, z = sp.symbols("x, y, z") + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) + ctx.set_iteration_space(ispace) + + first_loop = factory.loops_from_ispace( + ispace, + PsBlock( + [ + factory.parse_sympy(Assignment(x, 
y)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ), + ) + + second_loop = factory.loops_from_ispace( + ispace, + PsBlock( + [ + factory.parse_sympy(Assignment(x, z)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ), + ) + + ast = PsBlock([first_loop, second_loop]) + + result = canonicalize(ast) + result = hoist(result) + + assert isinstance(result, PsBlock) + + assert ( + isinstance(result.statements[0], PsDeclaration) + and result.statements[0].declared_symbol.name == "x__0" + ) + + assert result.statements[1] == first_loop + + assert ( + isinstance(result.statements[2], PsDeclaration) + and result.statements[2].declared_symbol.name == "x" + ) + + assert result.statements[3] == second_loop + + +def test_hoist_with_recurrence(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x, y = sp.symbols("x, y") + x_decl = factory.parse_sympy(Assignment(x, 1)) + x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1)) + y_decl = factory.parse_sympy(Assignment(y, 2 * x)) + + loop = factory.loop("i", make_slice[0:10:1], PsBlock([y_decl, x_update])) + + ast = PsBlock([x_decl, loop]) + + result = hoist(ast) + + # x is updated in the loop, so nothing can be hoisted + assert isinstance(result, PsBlock) + assert result.statements == [x_decl, loop] + assert loop.body.statements == [y_decl, x_update] + + +def test_hoist_with_conditionals(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x, y, z, w = sp.symbols("x, y, z, w") + x_decl = factory.parse_sympy(Assignment(x, 1)) + x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1)) + y_decl = factory.parse_sympy(Assignment(y, 2 * x)) + z_decl = factory.parse_sympy(Assignment(z, 312)) + w_decl = factory.parse_sympy(Assignment(w, 142)) + + cond = factory.parse_sympy(TypedSymbol("cond", Bool())) + + inner_conditional = PsConditional(cond, PsBlock([x_update, z_decl])) + loop = 
factory.loop( + "i", + make_slice[0:10:1], + PsBlock([y_decl, w_decl, inner_conditional]), + ) + outer_conditional = PsConditional(cond, PsBlock([loop])) + + ast = PsBlock([x_decl, outer_conditional]) + + result = hoist(ast) + + # z is hidden inside conditional, so z cannot be hoisted + # x is updated conditionally, so y cannot be hoisted + assert isinstance(result, PsBlock) + assert result.statements == [x_decl, outer_conditional] + assert outer_conditional.branch_true.statements == [w_decl, loop] + assert loop.body.statements == [y_decl, inner_conditional] + + +def test_hoist_arrays(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + const_arr_symb = TypedSymbol( + "const_arr", + Arr(Fp(64, const=True), 10), + ) + const_array_decl = factory.parse_sympy(Assignment(const_arr_symb, tuple(range(10)))) + const_arr = sp.IndexedBase(const_arr_symb, shape=(10,)) + + arr_symb = TypedSymbol( + "arr", + Arr(Fp(64, const=False), 10), + ) + array_decl = factory.parse_sympy(Assignment(arr_symb, tuple(range(10)))) + arr = sp.IndexedBase(arr_symb, shape=(10,)) + + x, y = sp.symbols("x, y") + + nonconst_usage = factory.parse_sympy(Assignment(x, arr[3])) + const_usage = factory.parse_sympy(Assignment(y, const_arr[3])) + body = PsBlock([array_decl, const_array_decl, nonconst_usage, const_usage]) + + loop = factory.loop_nest(("i", "j"), make_slice[:10, :42], body) + + result = hoist(loop) + + assert isinstance(result, PsBlock) + assert result.statements == [array_decl, const_array_decl, const_usage, loop] + + assert isinstance(loop.body.statements[0], PsLoop) + assert loop.body.statements[0].body.statements == [nonconst_usage] + + +def test_hoisting_eliminates_loops(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x, y, z = sp.symbols("x, y, z") + + invariant_decls = [ + factory.parse_sympy(Assignment(x, 42)), + factory.parse_sympy(Assignment(y, 2 * x)), + 
factory.parse_sympy(Assignment(z, x + 4 * y)), + ] + + ast = factory.loop_nest(("i", "j"), make_slice[:10, :42], PsBlock(invariant_decls)) + + ast = hoist(ast) + + assert isinstance(ast, PsBlock) + # All statements are hoisted and the loops are removed + assert ast.statements == invariant_decls -- GitLab