diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index e9ac5237b0565f08d8d604831df9448e24447b3f..1e3968bc0e4137b7a43796911de646cdddcb9ab6 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -21,7 +21,7 @@ Internal Representation The code generator translates the kernel from the SymPy frontend's symbolic language to an internal representation (IR), which is then emitted as code in the required dialect of C. -All names of classes associated with the internal kernel representation are prefixed `Ps...` +All names of classes associated with the internal kernel representation are prefixed ``Ps...`` to distinguis them from identically named front-end and SymPy classes. The IR comprises *symbols*, *constants*, *arrays*, the *iteration space* and the *abstract syntax tree*: diff --git a/docs/source/backend/translation.rst b/docs/source/backend/translation.rst index 1576759806d9c287d08929be5b3e827a1853051f..a4c7d36b58701f76e7e35283e12d01e705af1890 100644 --- a/docs/source/backend/translation.rst +++ b/docs/source/backend/translation.rst @@ -5,6 +5,9 @@ Kernel Translation .. autoclass:: pystencils.backend.kernelcreation.KernelCreationContext :members: +.. autoclass:: pystencils.backend.kernelcreation.AstFactory + :members: + .. autoclass:: pystencils.backend.kernelcreation.KernelAnalysis :members: diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index a6ef04ebda491c1d41fbdfcc7473de4ab72b48bc..0ea13c563e244c7dda9a47b9ba55d7eac4a0df1c 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -18,14 +18,14 @@ from ..exceptions import PsInternalCompilerError class UndefinedSymbolsCollector: - """Collector for undefined variables. + """Collect undefined symbols. - This class implements an AST visitor that collects all `PsTypedVariable`s that have been used + This class implements an AST visitor that collects all symbols that have been used in the AST without being defined prior to their usage. """ def __call__(self, node: PsAstNode) -> set[PsSymbol]: - """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage.""" + """Returns all symbols that occur in the given AST without being defined prior to their usage.""" return self.visit(node) def visit(self, node: PsAstNode) -> set[PsSymbol]: @@ -79,8 +79,8 @@ class UndefinedSymbolsCollector: """Returns the set of variables declared by the given node which are visible in the enclosing scope.""" match node: - case PsDeclaration(lhs, _): - return {lhs.symbol} + case PsDeclaration(): + return {node.declared_symbol} case ( PsAssignment() diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 8a66457a9f516f5089c75888c4cb753a296005b2..7c743a3997071a2f1515dd5802ee6f69cb741375 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -21,7 +21,7 @@ from .astnode import PsAstNode, PsLeafMixIn class PsExpression(PsAstNode, ABC): """Base class for all expressions. - + **Types:** Each expression should be annotated with its type. Upon construction, the `dtype` property of most expression nodes is unset; only constant expressions, symbol expressions, and array accesses immediately inherit their type from @@ -271,7 +271,7 @@ class PsVectorArrayAccess(PsArrayAccess): @property def alignment(self) -> int: return self._alignment - + def get_vector_type(self) -> PsVectorType: return cast(PsVectorType, self._dtype) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 441faa606fd615e75cdcf39db167399254fdec45..47342cfedcb7c748f8c460ae99512ec64800ce61 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -4,6 +4,7 @@ from types import NoneType from .astnode import PsAstNode, PsLeafMixIn from .expressions import PsExpression, PsLvalue, PsSymbolExpr +from ..symbols import PsSymbol from .util import failing_cast @@ -121,7 +122,7 @@ class PsAssignment(PsAstNode): class PsDeclaration(PsAssignment): __match_args__ = ( - "declared_variable", + "lhs", "rhs", ) @@ -137,12 +138,8 @@ class PsDeclaration(PsAssignment): self._lhs = failing_cast(PsSymbolExpr, lvalue) @property - def declared_variable(self) -> PsSymbolExpr: - return cast(PsSymbolExpr, self._lhs) - - @declared_variable.setter - def declared_variable(self, lvalue: PsSymbolExpr): - self._lhs = lvalue + def declared_symbol(self) -> PsSymbol: + return cast(PsSymbolExpr, self._lhs).symbol def clone(self) -> PsDeclaration: return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py index 125c1149ba7fec7ce279afef089229f70c75eb18..b867d89d34597ef524c36a1eb7e720b2dadf0cd2 100644 --- a/src/pystencils/backend/constants.py +++ b/src/pystencils/backend/constants.py @@ -7,10 +7,10 @@ from .exceptions import PsInternalCompilerError class PsConstant: """Type-safe representation of typed numerical constants. - + This class models constants in the backend representation of kernels. A constant may be *untyped*, in which case its ``value`` may be any Python object. - + If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used to check the validity of its ``value`` and to convert it into the type's internal representation. @@ -36,19 +36,19 @@ class PsConstant: def interpret_as(self, dtype: PsNumericType) -> PsConstant: """Interprets this *untyped* constant with the given data type. - + If this constant is already typed, raises an error. """ if self._dtype is not None: raise PsInternalCompilerError( f"Cannot interpret already typed constant {self} with type {dtype}" ) - + return PsConstant(self._value, dtype) - + def reinterpret_as(self, dtype: PsNumericType) -> PsConstant: """Reinterprets this constant with the given data type. - + Other than `interpret_as`, this method also works on typed constants. """ return PsConstant(self._value, dtype) @@ -60,7 +60,7 @@ class PsConstant: @property def dtype(self) -> PsNumericType | None: """This constant's data type, or ``None`` if it is untyped. - + The data type of a constant always has ``const == True``. """ return self._dtype diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index b742c598db27e46d94655e309f277531bdcb75eb..aa5f853a73301f118e8fea9e5654237392867fc9 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -169,7 +169,7 @@ class CAstPrinter: return pc.indent(f"{self.visit(expr, pc)};") case PsDeclaration(lhs, rhs): - lhs_symb = lhs.symbol + lhs_symb = node.declared_symbol lhs_code = self._symbol_decl(lhs_symb) rhs_code = self.visit(rhs, pc) diff --git a/src/pystencils/backend/kernelcreation/__init__.py b/src/pystencils/backend/kernelcreation/__init__.py index 1cbddab4f9dea32a30b95f211db935a542840bbe..5de83caadb3b4aa50112ef2b65c28c1ca7932aae 100644 --- a/src/pystencils/backend/kernelcreation/__init__.py +++ b/src/pystencils/backend/kernelcreation/__init__.py @@ -2,6 +2,7 @@ from .context import KernelCreationContext from .analysis import KernelAnalysis from .freeze import FreezeExpressions from .typification import Typifier +from .ast_factory import AstFactory from .iteration_space import ( FullIterationSpace, @@ -17,6 +18,7 @@ __all__ = [ "KernelAnalysis", "FreezeExpressions", "Typifier", + "AstFactory", "FullIterationSpace", "SparseIterationSpace", "create_full_iteration_space", diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..b9bbe8cce84dca53a02e2297e0e8cb5199a25b26 --- /dev/null +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -0,0 +1,187 @@ +from typing import Any, Sequence, cast, overload + +import sympy as sp +from sympy.codegen.ast import AssignmentBase + +from ..ast import PsAstNode +from ..ast.expressions import PsExpression, PsSymbolExpr +from ..ast.structural import PsLoop, PsBlock, PsAssignment + +from ..symbols import PsSymbol +from ..constants import PsConstant + +from .context import KernelCreationContext +from .freeze import FreezeExpressions +from .typification import Typifier +from .iteration_space import FullIterationSpace + + +class AstFactory: + """Factory providing a convenient interface for building syntax trees. + + The `AstFactory` uses the defaults provided by the given `KernelCreationContext` to quickly create + AST nodes. Depending on context (numerical, loop indexing, etc.), symbols and constants receive either + ``ctx.default_dtype`` or ``ctx.index_dtype``. + + Args: + ctx: The kernel creation context + """ + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + self._freeze = FreezeExpressions(ctx) + self._typify = Typifier(ctx) + + @overload + def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: + pass + + @overload + def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment: + pass + + def parse_sympy(self, sp_obj: sp.Expr | AssignmentBase) -> PsAstNode: + """Parse a SymPy expression or assignment through `FreezeExpressions` and `Typifier`. + + The expression or assignment will be typified in a numerical context, using the kernel + creation context's `default_dtype`. + + Args: + sp_obj: A SymPy expression or assignment + """ + return self._typify(self._freeze(sp_obj)) + + def parse_slice( + self, slic: slice, upper_limit: Any | None = None + ) -> tuple[PsExpression, PsExpression, PsExpression]: + """Parse a slice to obtain start, stop and step expressions for a loop or iteration space dimension. + + The slice entries may be instances of `PsExpression`, `PsSymbol` or `PsConstant`, in which case they + must typify with the kernel creation context's ``index_dtype``. + They may also be sympy expressions or integer constants, in which case they are parsed to AST objects + and must also typify with the kernel creation context's ``index_dtype``. + + If the slice's ``stop`` member is `None` or a negative `int`, `upper_limit` must be specified, which is then + used as the upper iteration limit as either ``upper_limit`` or ``upper_limit - stop``. + + Args: + slic: The iteration slice + upper_limit: Optionally, the upper iteration limit + """ + + if slic.stop is None or (isinstance(slic.stop, int) and slic.stop < 0): + if upper_limit is None: + raise ValueError( + "Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`" + ) + + def make_expr(val: Any) -> PsExpression: + match val: + case PsExpression(): + return self._typify.typify_expression(val, self._ctx.index_dtype)[0] + case PsSymbol() | PsConstant(): + return self._typify.typify_expression( + PsExpression.make(val), self._ctx.index_dtype + )[0] + case sp.Expr(): + return self._typify.typify_expression( + self._freeze(val), self._ctx.index_dtype + )[0] + case _: + return PsExpression.make(PsConstant(val, self._ctx.index_dtype)) + + start = make_expr(slic.start if slic.start is not None else 0) + stop = make_expr(slic.stop) if slic.stop is not None else make_expr(upper_limit) + step = make_expr(slic.step if slic.step is not None else 1) + + if isinstance(slic.stop, int) and slic.stop < 0: + stop = make_expr(upper_limit) + stop + + return start, stop, step + + def loop(self, ctr_name: str, iteration_slice: slice, body: PsBlock): + """Create a loop from a slice. + + Args: + ctr_name: Name of the loop counter + iteration_slice: The iteration region as a slice; see `parse_slice`. + body: The loop body + """ + ctr = PsExpression.make(self._ctx.get_symbol(ctr_name, self._ctx.index_dtype)) + + start, stop, step = self.parse_slice(iteration_slice) + + return PsLoop( + ctr, + start, + stop, + step, + body, + ) + + def loop_nest(self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock) -> PsLoop: + """Create a loop nest from a sequence of slices. + + **Example:** + This snippet creates a 3D loop nest with ten iterations in each dimension:: + + >>> from pystencils import make_slice + >>> ctx = KernelCreationContext() + >>> factory = AstFactory(ctx) + >>> loop = factory.loop_nest(("i", "j", "k"), make_slice[:10,:10,:10], PsBlock([])) + + Args: + counters: Sequence of names for the loop counters + slices: Sequence of iteration slices; see also `parse_slice` + body: The loop body + """ + if not slices: + raise ValueError( + "At least one slice must be specified to create a loop nest." + ) + + ast = body + for ctr_name, sl in zip(counters[::-1], slices[::-1], strict=True): + ast = self.loop( + ctr_name, + sl, + PsBlock([ast]) if not isinstance(ast, PsBlock) else ast, + ) + + return cast(PsLoop, ast) + + def loops_from_ispace( + self, + ispace: FullIterationSpace, + body: PsBlock, + loop_order: Sequence[int] | None = None, + ) -> PsLoop: + """Create a loop nest from a dense iteration space. + + Args: + ispace: The iteration space object + body: The loop body + loop_order: Optionally, a permutation of integers indicating the order of loops + """ + dimensions = ispace.dimensions + + if loop_order is not None: + dimensions = [dimensions[coordinate] for coordinate in loop_order] + + outer_node: PsLoop | PsBlock = body + + for dimension in dimensions[::-1]: + outer_node = PsLoop( + PsSymbolExpr(dimension.counter), + dimension.start, + dimension.stop, + dimension.step, + ( + outer_node + if isinstance(outer_node, PsBlock) + else PsBlock([outer_node]) + ), + ) + + assert isinstance(outer_node, PsLoop) + return outer_node diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 5ce373797caa2f7e88957ab38921d0856b003b0c..d48953a5b2ccaf1325006f09ec63f5f8fea90f26 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -106,6 +106,14 @@ class KernelCreationContext: return symb + def find_symbol(self, name: str) -> PsSymbol | None: + """Find a symbol with the given name in the symbol table, if it exists. + + Returns: + The symbol with the given name, or `None` if no such symbol exists. + """ + return self._symbols.get(name, None) + def add_symbol(self, symbol: PsSymbol): if symbol.name in self._symbols: raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}") diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py index 47db578232578ef9fbf48cb8b34fde8e549ab00e..b0156c7e8ce0b9cac2c6f3be9f60bbeef41e1c51 100644 --- a/src/pystencils/backend/kernelcreation/cpu_optimization.py +++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py @@ -1,7 +1,9 @@ from __future__ import annotations +from typing import cast from .context import KernelCreationContext from ..platforms import GenericCpu +from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations from ..ast.structural import PsBlock from ...config import CpuOptimConfig @@ -11,10 +13,19 @@ def optimize_cpu( ctx: KernelCreationContext, platform: GenericCpu, kernel_ast: PsBlock, - cfg: CpuOptimConfig, -): + cfg: CpuOptimConfig | None, +) -> PsBlock: """Carry out CPU-specific optimizations according to the given configuration.""" + canonicalize = CanonicalizeSymbols(ctx, True) + kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) + + hoist_invariants = HoistLoopInvariantDeclarations(ctx) + kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast)) + + if cfg is None: + return kernel_ast + if cfg.loop_blocking: raise NotImplementedError("Loop blocking not implemented yet.") @@ -26,3 +37,5 @@ def optimize_cpu( if cfg.use_cacheline_zeroing: raise NotImplementedError("CL-zeroing not implemented yet") + + return kernel_ast diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 6ce0264e20c2ba4b0ecb84bb6c3485890a6fd3a2..a9f760e9718742bfc16a2bee7c60f14f3f272be3 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -1,8 +1,9 @@ from typing import overload, cast, Any from functools import reduce -from operator import add, mul, sub +from operator import add, mul, sub, truediv import sympy as sp +from sympy.codegen.ast import AssignmentBase, AugmentedAssignment from ...sympyextensions import Assignment, AssignmentCollection, integer_functions from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc @@ -68,13 +69,13 @@ class FreezeExpressions: pass @overload - def __call__(self, obj: Assignment) -> PsAssignment: + def __call__(self, obj: AssignmentBase) -> PsAssignment: pass def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode: if isinstance(obj, AssignmentCollection): return PsBlock([self.visit(asm) for asm in obj.all_assignments]) - elif isinstance(obj, Assignment): + elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) elif isinstance(obj, sp.Expr): return cast(PsExpression, self.visit(obj)) @@ -128,6 +129,27 @@ class FreezeExpressions: f"Encountered unsupported expression on assignment left-hand side: {lhs}" ) + def map_AugmentedAssignment(self, expr: AugmentedAssignment): + lhs = self.visit(expr.lhs) + rhs = self.visit(expr.rhs) + + assert isinstance(lhs, PsExpression) + assert isinstance(rhs, PsExpression) + + match expr.op: + case "+=": + op = add + case "-=": + op = sub + case "*=": + op = mul + case "/=": + op = truediv + case _: + raise FreezeError(f"Unsupported augmented assignment: {expr.op}.") + + return PsAssignment(lhs, op(lhs.clone(), rhs)) + def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) return PsSymbolExpr(symb) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 382adf7b651a5cf5e279da808a2f6fa661415610..5a093031cb4a498a4a00ff2020df0d69747a2b70 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -5,8 +5,6 @@ from dataclasses import dataclass from functools import reduce from operator import mul -import sympy as sp - from ...defaults import DEFAULTS from ...sympyextensions import AssignmentCollection from ...field import Field, FieldType @@ -71,8 +69,8 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_with_ghost_layers( ctx: KernelCreationContext, - archetype_field: Field, ghost_layers: int | Sequence[int | tuple[int, int]], + archetype_field: Field, ) -> FullIterationSpace: """Create an iteration space over an archetype field with ghost layers.""" @@ -123,56 +121,52 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_from_slice( ctx: KernelCreationContext, - archetype_field: Field, iteration_slice: Sequence[slice], + archetype_field: Field | None = None, ): - archetype_array = ctx.get_array(archetype_field) - dim = archetype_field.spatial_dimensions - - if len(iteration_slice) != dim: + """Create an iteration space from a sequence of slices, optionally over an archetype field. + + Args: + ctx: The kernel creation context + iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` + archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order. + """ + dim = len(iteration_slice) + if dim == 0: raise ValueError( - f"Number of dimensions in slice ({len(iteration_slice)}) " - f" did not equal iteration space dimensionality ({dim})" + "At least one slice must be specified to create an iteration space" ) - counters = [ - ctx.get_symbol(name, ctx.index_dtype) - for name in DEFAULTS.spatial_counter_names[:dim] - ] - - from .freeze import FreezeExpressions - from .typification import Typifier - - freeze = FreezeExpressions(ctx) - typifier = Typifier(ctx) + archetype_size: tuple[PsSymbol | PsConstant | None, ...] + if archetype_field is not None: + archetype_array = ctx.get_array(archetype_field) - def expr_convert(expr) -> PsExpression: - if isinstance(expr, int): - return PsConstantExpr(PsConstant(expr, ctx.index_dtype)) - elif isinstance(expr, sp.Expr): - typed_expr, _ = typifier.typify_expression( - freeze.freeze_expression(expr), ctx.index_dtype + if archetype_field.spatial_dimensions != dim: + raise ValueError( + f"Number of dimensions in slice ({len(iteration_slice)}) " + f" did not equal iteration space dimensionality ({dim})" ) - return typed_expr - else: - raise ValueError(f"Invalid entry in slice: {expr}") - def to_dim(slic: slice, size: PsSymbol | PsConstant, ctr: PsSymbol): - size_expr = PsExpression.make(size) + archetype_size = archetype_array.shape[:dim] + else: + archetype_size = (None,) * dim - start = expr_convert(slic.start if slic.start is not None else 0) - stop = expr_convert(slic.stop) if slic.stop is not None else size_expr - step = expr_convert(slic.step if slic.step is not None else 1) + counters = [ + ctx.get_symbol(name, ctx.index_dtype) + for name in DEFAULTS.spatial_counter_names[:dim] + ] - if isinstance(slic.stop, int) and slic.stop < 0: - stop = size_expr + stop # todo + from .ast_factory import AstFactory + factory = AstFactory(ctx) + def to_dim(slic: slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol): + start, stop, step = factory.parse_slice(slic, size) return FullIterationSpace.Dimension(start, stop, step, ctr) dimensions = [ to_dim(slic, size, ctr) for slic, size, ctr in zip( - iteration_slice, archetype_array.shape[:dim], counters, strict=True + iteration_slice, archetype_size, counters, strict=True ) ] @@ -399,13 +393,13 @@ def create_full_iteration_space( if ghost_layers is not None: return FullIterationSpace.create_with_ghost_layers( - ctx, archetype_field, ghost_layers + ctx, ghost_layers, archetype_field ) elif iteration_slice is not None: return FullIterationSpace.create_from_slice( - ctx, archetype_field, iteration_slice + ctx, iteration_slice, archetype_field ) else: return FullIterationSpace.create_with_ghost_layers( - ctx, archetype_field, inferred_gls + ctx, inferred_gls, archetype_field ) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 190cd9e23b715ac5bca2a4bc6fd119013376cb49..707b02c667ee1aae815ea4d7176bdb7694eb95d5 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -56,7 +56,7 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: """Typing context, with support for type inference and checking. - + Instances of this class are used to propagate and check data types across expression subtrees of the AST. Each type context has: @@ -185,7 +185,7 @@ class TypeContext: def _compatible(self, dtype: PsType): """Checks whether the given data type is compatible with the context's target type. - + If the target type is ``const``, they must be equal up to const qualification; if the target type is not ``const``, `dtype` must match it exactly. """ @@ -248,7 +248,7 @@ class Typifier: Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but not necessarily their const-qualification. - A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, + A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`, but not vice versa. """ @@ -321,7 +321,7 @@ class Typifier: def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None: """Recursive processing of expression nodes. - + This method opens, expands, and closes typing contexts according to the respective expression's typing rules. It may add or check restrictions only when opening or closing a type context. diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index a83bad4badf89a95f0d3efbd61c0bc355b534b2a..6899ac9474d303623f79e2bdc7c3765c64380a6c 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -7,6 +7,7 @@ from ...types import PsType, PsIeeeFloatType from .platform import Platform from ..exceptions import MaterializationError +from ..kernelcreation import AstFactory from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, @@ -76,28 +77,17 @@ class GenericCpu(Platform): def _create_domain_loops( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: - - dimensions = ispace.dimensions + factory = AstFactory(self._ctx) # Determine loop order by permuting dimensions archetype_field = ispace.archetype_field if archetype_field is not None: loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] - - outer_block = body - - for dimension in dimensions[::-1]: - loop = PsLoop( - PsSymbolExpr(dimension.counter), - dimension.start, - dimension.stop, - dimension.step, - outer_block, - ) - outer_block = PsBlock([loop]) + else: + loop_order = None - return outer_block + loops = factory.loops_from_ispace(ispace, body, loop_order) + return PsBlock([loops]) def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace): mappings = [ diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 64c0cd3e94b8bbae675417f165e9351a4d122f72..27fcdfac2b6730448b46f0d784292ca9cb577684 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -56,8 +56,10 @@ class GenericGpu(Platform): ] return indices[:dim] - - def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction: + + def select_function( + self, math_function: PsMathFunction, dtype: PsType + ) -> CFunction: raise NotImplementedError() # Internals diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 8ef35e4fbcfd864be05c1d9bcedc7ee83b3d6a96..afb1e4fcd52d2f0fd85a008ca24987f085fd7dc6 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,10 +1,14 @@ from .eliminate_constants import EliminateConstants +from .canonicalize_symbols import CanonicalizeSymbols +from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", + "CanonicalizeSymbols", + "HoistLoopInvariantDeclarations", "EraseAnonymousStructTypes", "SelectFunctions", "MaterializeVectorIntrinsics", diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe922f28cfbfd0a42c77dbd3d1606c910abf298 --- /dev/null +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -0,0 +1,125 @@ +from itertools import count + +from ..kernelcreation import KernelCreationContext +from ..symbols import PsSymbol +from ..exceptions import PsInternalCompilerError + +from ..ast import PsAstNode +from ..ast.structural import PsDeclaration, PsAssignment, PsLoop, PsConditional, PsBlock +from ..ast.expressions import PsSymbolExpr, PsExpression + +from ...types import constify + +__all__ = ["CanonicalizeSymbols"] + + +class CanonContext: + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self.encountered_symbols: set[PsSymbol] = set() + self.live_symbols_map: dict[PsSymbol, PsSymbol] = dict() + + self.updated_symbols: set[PsSymbol] = set() + + def deduplicate(self, symb: PsSymbol) -> PsSymbol: + if symb in self.live_symbols_map: + return self.live_symbols_map[symb] + elif symb not in self.encountered_symbols: + self.encountered_symbols.add(symb) + self.live_symbols_map[symb] = symb + return symb + else: + for i in count(): + replacement_name = f"{symb.name}__{i}" + if self._ctx.find_symbol(replacement_name) is None: + replacement = self._ctx.get_symbol(replacement_name, symb.dtype) + self.live_symbols_map[symb] = replacement + self.encountered_symbols.add(replacement) + return replacement + assert False, "unreachable code" + + def mark_as_updated(self, symb: PsSymbol): + self.updated_symbols.add(self.deduplicate(symb)) + + def is_live(self, symb: PsSymbol) -> bool: + return symb in self.live_symbols_map + + def end_lifespan(self, symb: PsSymbol): + if symb in self.live_symbols_map: + del self.live_symbols_map[symb] + + +class CanonicalizeSymbols: + """Remove duplicate symbol declarations and declare all non-updated symbols ``const``. + + The `CanonicalizeSymbols` pass will remove multiple declarations of the same symbol by + renaming all but the last occurence, and will optionally ``const``-qualify all symbols + encountered in the AST that are never updated. + """ + + def __init__(self, ctx: KernelCreationContext, constify: bool = True) -> None: + self._ctx = ctx + self._constify = constify + self._last_result: CanonContext | None = None + + def get_last_live_symbols(self) -> set[PsSymbol]: + if self._last_result is None: + raise PsInternalCompilerError("Pass was not executed yet") + return set(self._last_result.live_symbols_map.values()) + + def __call__(self, node: PsAstNode) -> PsAstNode: + cc = CanonContext(self._ctx) + self.visit(node, cc) + + # Any symbol encountered but never updated can be marked const + if self._constify: + for symb in cc.encountered_symbols - cc.updated_symbols: + if symb.dtype is not None: + symb.dtype = constify(symb.dtype) + + # Any symbols still alive now are function params or globals + # Might use that to populate KernelFunction + self._last_result = cc + + return node + + def visit(self, node: PsAstNode, cc: CanonContext): + """Traverse the AST in reverse pre-order to collect, deduplicate, and maybe constify all live symbols.""" + + match node: + case PsSymbolExpr(symb): + node.symbol = cc.deduplicate(symb) + return node + + case PsExpression(): + for c in node.children: + self.visit(c, cc) + + case PsDeclaration(lhs, rhs): + decl_symb = node.declared_symbol + self.visit(lhs, cc) + self.visit(rhs, cc) + cc.end_lifespan(decl_symb) + + case PsAssignment(lhs, rhs): + self.visit(lhs, cc) + self.visit(rhs, cc) + + if isinstance(lhs, PsSymbolExpr): + cc.mark_as_updated(lhs.symbol) + + case PsLoop(ctr, _, _, _, _): + for c in node.children[::-1]: + self.visit(c, cc) + cc.mark_as_updated(ctr.symbol) + cc.end_lifespan(ctr.symbol) + + case PsConditional(cond, then, els): + if els is not None: + self.visit(els, cc) + self.visit(then, cc) + self.visit(cond, cc) + + case PsBlock(statements): + for stmt in statements[::-1]: + self.visit(stmt, cc) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 808743d694f67166315b89f1d4abf012c8348f9a..22ad740faca992dc9f520f160a406912af4385af 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -20,7 +20,6 @@ from ..ast.util import AstEqWrapper from ..constants import PsConstant from ..symbols import PsSymbol from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError -from ..emission import CAstPrinter __all__ = ["EliminateConstants"] @@ -32,6 +31,9 @@ class ECContext: self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() self._typifier = Typifier(ctx) + + from ..emission import CAstPrinter + self._printer = CAstPrinter(0) @property diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py new file mode 100644 index 0000000000000000000000000000000000000000..5920038150ee6c5295c3f46f4530639ed02fca25 --- /dev/null +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -0,0 +1,181 @@ +from typing import cast + +from ..kernelcreation import KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment +from ..ast.expressions import ( + PsExpression, + PsSymbolExpr, + PsConstantExpr, + PsCall, + PsDeref, + PsSubscript, + PsUnOp, + PsBinOp, + PsArrayInitList, +) + +from ...types import PsDereferencableType +from ..symbols import PsSymbol +from ..functions import PsMathFunction + +__all__ = ["HoistLoopInvariantDeclarations"] + + +class HoistContext: + def __init__(self) -> None: + self.hoisted_nodes: list[PsDeclaration] = [] + self.assigned_symbols: set[PsSymbol] = set() + self.invariant_symbols: set[PsSymbol] = set() + + def _is_invariant(self, expr: PsExpression) -> bool: + def args_invariant(expr): + return all( + self._is_invariant(cast(PsExpression, arg)) for arg in expr.children + ) + + match expr: + case PsSymbolExpr(symbol): + return (symbol not in self.assigned_symbols) or ( + symbol in self.invariant_symbols + ) + + case PsConstantExpr(): + return True + + case PsCall(func): + return isinstance(func, PsMathFunction) and args_invariant(expr) + + case PsSubscript(ptr, _) | PsDeref(ptr): + ptr_type = cast(PsDereferencableType, ptr.get_dtype()) + return ptr_type.base_type.const and args_invariant(expr) + + case PsUnOp() | PsBinOp() | PsArrayInitList(): + return args_invariant(expr) + + case _: + return False + + +class HoistLoopInvariantDeclarations: + """Hoist loop-invariant declarations out of the loop nest. + + This transformation moves loop-invariant symbol declarations outside of the loop + nest to prevent their repeated execution within the loops. + If this transformation results in the complete elimination of a loop body, the respective loop + is removed. + + `HoistLoopInvariantDeclarations` assumes that symbols are canonical; + in particular, each symbol may have at most one declaration. + To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`. + + `HoistLoopInvariants` assumes that all `PsMathFunction`s are pure (have no side effects), + but makes no such assumption about instances of `CFunction`. + """ + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self.visit(node) + + def visit(self, node: PsAstNode) -> PsAstNode: + """Search the outermost loop and start the hoisting cascade there.""" + match node: + case PsLoop(): + temp_block = PsBlock([node]) + temp_block = cast(PsBlock, self.visit(temp_block)) + if temp_block.statements == [node]: + return node + else: + return temp_block + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsLoop): + loop = stmt + hc = self._hoist(loop) + statements_new += hc.hoisted_nodes + if loop.body.statements: + statements_new.append(loop) + else: + self.visit(stmt) + statements_new.append(stmt) + + node.statements = statements_new + return node + + case PsConditional(_, then, els): + self.visit(then) + if els is not None: + self.visit(els) + return node + + case _: + # if the node is none of the above, end the search + return node + + # end match + + def _hoist(self, loop: PsLoop) -> HoistContext: + """Hoist invariant declarations out of the given loop.""" + hc = HoistContext() + hc.assigned_symbols.add(loop.counter.symbol) + self._prepare_hoist(loop.body, hc) + self._hoist_from_block(loop.body, hc) + return hc + + def _prepare_hoist(self, node: PsAstNode, hc: HoistContext): + """Collect all symbols assigned within a loop body, + and recursively apply loop-invariant code motion to any nested loops.""" + match node: + case PsExpression(): + return + + case PsAssignment(PsSymbolExpr(lhs_symb), _): + hc.assigned_symbols.add(lhs_symb) + + case PsAssignment(_, _): + return + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsLoop): + loop = stmt + nested_hc = self._hoist(loop) + hc.assigned_symbols |= nested_hc.assigned_symbols + statements_new += nested_hc.hoisted_nodes + if loop.body.statements: + statements_new.append(loop) + else: + self._prepare_hoist(stmt, hc) + statements_new.append(stmt) + node.statements = statements_new + + case _: + for c in node.children: + self._prepare_hoist(c, hc) + + def _hoist_from_block(self, block: PsBlock, hc: HoistContext): + """Hoist invariant declarations from the given block, and any directly nested blocks. + + This method processes only statements of the given block, and any blocks directly nested inside it. + It does not descend into control structures like conditionals and nested loops. + """ + statements_new: list[PsAstNode] = [] + + for node in block.statements: + if isinstance(node, PsDeclaration): + if hc._is_invariant(node.rhs): + hc.hoisted_nodes.append(node) + hc.invariant_symbols.add(node.declared_symbol) + else: + statements_new.append(node) + else: + if isinstance(node, PsBlock): + self._hoist_from_block(node, hc) + statements_new.append(node) + + block.statements = statements_new diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 3f17e66bd7e8df097e48c1c2607e13d9844d9990..0f6941cf52ca1c3745b5af9a1f0478f269fc8e4a 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -25,7 +25,11 @@ from .backend.kernelcreation.iteration_space import ( ) from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants, SelectFunctions +from .backend.transformations import ( + EliminateConstants, + EraseAnonymousStructTypes, + SelectFunctions, +) from .sympyextensions import AssignmentCollection, Assignment @@ -38,7 +42,7 @@ def create_kernel( config: CreateKernelConfig = CreateKernelConfig(), ) -> KernelFunction: """Create a kernel function from a set of assignments. - + Args: assignments: The kernel's sequence of assignments, expressed using SymPy config: The configuration for the kernel translator @@ -84,6 +88,7 @@ def create_kernel( match config.target: case Target.GenericCPU: from .backend.platforms import GenericCpu + platform = GenericCpu(ctx) case _: # TODO: CUDA/HIP platform @@ -96,13 +101,13 @@ def create_kernel( elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) kernel_ast = cast(PsBlock, elim_constants(kernel_ast)) - erase_anons = EraseAnonymousStructTypes(ctx) - kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) - # Target-Specific optimizations - if config.target.is_cpu() and config.cpu_optim is not None: + if config.target.is_cpu(): from .backend.kernelcreation import optimize_cpu - optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) + kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) + + erase_anons = EraseAnonymousStructTypes(ctx) + kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) select_functions = SelectFunctions(platform) kernel_ast = cast(PsBlock, select_functions(kernel_ast)) diff --git a/tests/nbackend/kernelcreation/platform/test_basic_cpu.py b/tests/nbackend/kernelcreation/platform/test_basic_cpu.py index 7bfcd4e425a3011b983cd7361806c57415c121cc..79043c880c819f4e9fc9a7bfeb62e4b4ac955611 100644 --- a/tests/nbackend/kernelcreation/platform/test_basic_cpu.py +++ b/tests/nbackend/kernelcreation/platform/test_basic_cpu.py @@ -23,7 +23,7 @@ def test_loop_nest(layout): # FZYX Order archetype_field = Field.create_generic("field", spatial_dimensions=3, layout=layout) - ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0) + ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field) loop_nest = platform.materialize_iteration_space(body, ispace) diff --git a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py index df0e48fb079ec169486544e56588872b248e357c..e47f38e4d30c3f94ded9469c7a7351e9a3f298da 100644 --- a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py +++ b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py @@ -23,6 +23,6 @@ def test_loop_nest(layout): # FZYX Order archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout=layout) - ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0) + ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field) condition = platform.materialize_iteration_space(body, ispace) diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 6b4145e98e6eca749886d06d6b2c644245fb293c..1dfcfea2a9bfd8e9db70dab7ca61732523855843 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -1,6 +1,5 @@ import pytest -from pystencils.defaults import DEFAULTS from pystencils.field import Field from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type @@ -8,6 +7,7 @@ from pystencils.backend.kernelcreation import KernelCreationContext, FullIterati from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.kernelcreation.typification import TypificationError +from pystencils.types import PsTypeError def test_slices(): @@ -17,7 +17,7 @@ def test_slices(): ctx.add_field(archetype_field) islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, -1)) - ispace = FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) archetype_arr = ctx.get_array(archetype_field) @@ -52,9 +52,9 @@ def test_invalid_slices(): ctx.add_field(archetype_field) islice = (slice(1, -1, 0.5),) - with pytest.raises(ValueError): - FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + with pytest.raises(PsTypeError): + FullIterationSpace.create_from_slice(ctx, islice, archetype_field) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) with pytest.raises(TypificationError): - FullIterationSpace.create_from_slice(ctx, archetype_field, islice) + FullIterationSpace.create_from_slice(ctx, islice, archetype_field) diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..43f269163f592116aaed19c44e003b882400f498 --- /dev/null +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -0,0 +1,88 @@ +# type: ignore +import sympy as sp + +from pystencils import Field, Assignment, AddAugmentedAssignment, make_slice, DEFAULTS + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import CanonicalizeSymbols +from pystencils.backend.ast.structural import PsConditional, PsBlock + + +def test_deduplication(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + + f = Field.create_fixed_size("f", (5, 5), strides=(5, 1)) + x, y, z = sp.symbols("x, y, z") + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) + ctx.set_iteration_space(ispace) + + ctr_1 = DEFAULTS.spatial_counters[1] + + then_branch = PsBlock( + [ + factory.parse_sympy(Assignment(x, y)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ) + + else_branch = PsBlock( + [ + factory.parse_sympy(Assignment(x, z)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ) + + ast = PsConditional( + factory.parse_sympy(ctr_1), + then_branch, + else_branch, + ) + + ast = factory.loops_from_ispace(ispace, PsBlock([ast])) + + ast = canonicalize(ast) + + assert canonicalize.get_last_live_symbols() == { + ctx.find_symbol("y"), + ctx.find_symbol("z"), + ctx.get_array(f).base_pointer, + } + + assert ctx.find_symbol("x") is not None + assert ctx.find_symbol("x__0") is not None + + assert then_branch.statements[0].declared_symbol.name == "x__0" + assert then_branch.statements[1].rhs.symbol.name == "x__0" + + assert else_branch.statements[0].declared_symbol.name == "x" + assert else_branch.statements[1].rhs.symbol.name == "x" + + assert ctx.find_symbol("x").dtype.const + assert ctx.find_symbol("x__0").dtype.const + assert ctx.find_symbol("y").dtype.const + assert ctx.find_symbol("z").dtype.const + + +def test_do_not_constify(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + + x, z = sp.symbols("x, z") + + ast = factory.loop("i", make_slice[:10], PsBlock([ + factory.parse_sympy(Assignment(x, z)), + factory.parse_sympy(AddAugmentedAssignment(z, 1)) + ])) + + ast = canonicalize(ast) + + assert ctx.find_symbol("x").dtype.const + assert not ctx.find_symbol("z").dtype.const diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py new file mode 100644 index 0000000000000000000000000000000000000000..db78efce54e75650545960e0edc3877dcc57d951 --- /dev/null +++ b/tests/nbackend/transformations/test_hoist_invariants.py @@ -0,0 +1,195 @@ +import sympy as sp + +from pystencils import ( + Field, + TypedSymbol, + Assignment, + AddAugmentedAssignment, + make_slice, +) +from pystencils.types.quick import Arr, Fp, Bool + +from pystencils.backend.ast.structural import ( + PsBlock, + PsLoop, + PsConditional, + PsDeclaration, +) + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import ( + CanonicalizeSymbols, + HoistLoopInvariantDeclarations, +) + + +def test_hoist_multiple_loops(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canonicalize = CanonicalizeSymbols(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + f = Field.create_fixed_size("f", (5, 5), strides=(5, 1)) + x, y, z = sp.symbols("x, y, z") + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) + ctx.set_iteration_space(ispace) + + first_loop = factory.loops_from_ispace( + ispace, + PsBlock( + [ + factory.parse_sympy(Assignment(x, y)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ), + ) + + second_loop = factory.loops_from_ispace( + ispace, + PsBlock( + [ + factory.parse_sympy(Assignment(x, z)), + factory.parse_sympy(Assignment(f.center(0), x)), + ] + ), + ) + + ast = PsBlock([first_loop, second_loop]) + + result = canonicalize(ast) + result = hoist(result) + + assert isinstance(result, PsBlock) + + assert ( + isinstance(result.statements[0], PsDeclaration) + and result.statements[0].declared_symbol.name == "x__0" + ) + + assert result.statements[1] == first_loop + + assert ( + isinstance(result.statements[2], PsDeclaration) + and result.statements[2].declared_symbol.name == "x" + ) + + assert result.statements[3] == second_loop + + +def test_hoist_with_recurrence(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x, y = sp.symbols("x, y") + x_decl = factory.parse_sympy(Assignment(x, 1)) + x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1)) + y_decl = factory.parse_sympy(Assignment(y, 2 * x)) + + loop = factory.loop("i", make_slice[0:10:1], PsBlock([y_decl, x_update])) + + ast = PsBlock([x_decl, loop]) + + result = hoist(ast) + + # x is updated in the loop, so nothing can be hoisted + assert isinstance(result, PsBlock) + assert result.statements == [x_decl, loop] + assert loop.body.statements == [y_decl, x_update] + + +def test_hoist_with_conditionals(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x, y, z, w = sp.symbols("x, y, z, w") + x_decl = factory.parse_sympy(Assignment(x, 1)) + x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1)) + y_decl = factory.parse_sympy(Assignment(y, 2 * x)) + z_decl = factory.parse_sympy(Assignment(z, 312)) + w_decl = factory.parse_sympy(Assignment(w, 142)) + + cond = factory.parse_sympy(TypedSymbol("cond", Bool())) + + inner_conditional = PsConditional(cond, PsBlock([x_update, z_decl])) + loop = factory.loop( + "i", + make_slice[0:10:1], + PsBlock([y_decl, w_decl, inner_conditional]), + ) + outer_conditional = PsConditional(cond, PsBlock([loop])) + + ast = PsBlock([x_decl, outer_conditional]) + + result = hoist(ast) + + # z is hidden inside conditional, so z cannot be hoisted + # x is updated conditionally, so y cannot be hoisted + assert isinstance(result, PsBlock) + assert result.statements == [x_decl, outer_conditional] + assert outer_conditional.branch_true.statements == [w_decl, loop] + assert loop.body.statements == [y_decl, inner_conditional] + + +def test_hoist_arrays(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + const_arr_symb = TypedSymbol( + "const_arr", + Arr(Fp(64, const=True), 10), + ) + const_array_decl = factory.parse_sympy(Assignment(const_arr_symb, tuple(range(10)))) + const_arr = sp.IndexedBase(const_arr_symb, shape=(10,)) + + arr_symb = TypedSymbol( + "arr", + Arr(Fp(64, const=False), 10), + ) + array_decl = factory.parse_sympy(Assignment(arr_symb, tuple(range(10)))) + arr = sp.IndexedBase(arr_symb, shape=(10,)) + + x, y = sp.symbols("x, y") + + nonconst_usage = factory.parse_sympy(Assignment(x, arr[3])) + const_usage = factory.parse_sympy(Assignment(y, const_arr[3])) + body = PsBlock([array_decl, const_array_decl, nonconst_usage, const_usage]) + + loop = factory.loop_nest(("i", "j"), make_slice[:10, :42], body) + + result = hoist(loop) + + assert isinstance(result, PsBlock) + assert result.statements == [array_decl, const_array_decl, const_usage, loop] + + assert isinstance(loop.body.statements[0], PsLoop) + assert loop.body.statements[0].body.statements == [nonconst_usage] + + +def test_hoisting_eliminates_loops(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + hoist = HoistLoopInvariantDeclarations(ctx) + + x, y, z = sp.symbols("x, y, z") + + invariant_decls = [ + factory.parse_sympy(Assignment(x, 42)), + factory.parse_sympy(Assignment(y, 2 * x)), + factory.parse_sympy(Assignment(z, x + 4 * y)), + ] + + ast = factory.loop_nest(("i", "j"), make_slice[:10, :42], PsBlock(invariant_decls)) + + ast = hoist(ast) + + assert isinstance(ast, PsBlock) + # All statements are hoisted and the loops are removed + assert ast.statements == invariant_decls