diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index 1e3968bc0e4137b7a43796911de646cdddcb9ab6..df194bde9b9c50cc9ee3f1064ff5b5361205c227 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -14,6 +14,7 @@ who wish to customize or extend the behaviour of the code generator in their app iteration_space translation platforms + transformations jit Internal Representation diff --git a/docs/source/backend/transformations.rst b/docs/source/backend/transformations.rst new file mode 100644 index 0000000000000000000000000000000000000000..44bf4da23e160edaac5e2dd9918fbd389aba94d6 --- /dev/null +++ b/docs/source/backend/transformations.rst @@ -0,0 +1,7 @@ +******************* +AST Transformations +******************* + +`pystencils.backend.transformations` + +.. automodule:: pystencils.backend.transformations diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index c2334f54c34d476207eddc5466b2b13bff0d39d8..83c406b0a99d52cee9599f321d6c32477f6dbf8a 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -1,10 +1,11 @@ from typing import Any, Sequence, cast, overload +import numpy as np import sympy as sp from sympy.codegen.ast import AssignmentBase from ..ast import PsAstNode -from ..ast.expressions import PsExpression, PsSymbolExpr +from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr from ..ast.structural import PsLoop, PsBlock, PsAssignment from ..symbols import PsSymbol @@ -16,6 +17,10 @@ from .typification import Typifier from .iteration_space import FullIterationSpace +IndexParsable = PsExpression | PsSymbol | PsConstant | sp.Expr | int | np.integer +_IndexParsable = (PsExpression, PsSymbol, PsConstant, sp.Expr, int, np.integer) + + class AstFactory: """Factory providing a convenient interface for building syntax trees. @@ -51,6 +56,45 @@ class AstFactory: """ return self._typify(self._freeze(sp_obj)) + @overload + def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr: + pass + + @overload + def parse_index( + self, idx: int | np.integer | PsConstant | PsConstantExpr + ) -> PsConstantExpr: + pass + + @overload + def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression: + pass + + def parse_index(self, idx: IndexParsable): + """Parse the given object as an expression with data type `ctx.index_dtype`.""" + + if not isinstance(idx, _IndexParsable): + raise TypeError( + f"Cannot parse object of type {type(idx)} as an index expression" + ) + + match idx: + case PsExpression(): + return self._typify.typify_expression(idx, self._ctx.index_dtype)[0] + case PsSymbol() | PsConstant(): + return self._typify.typify_expression( + PsExpression.make(idx), self._ctx.index_dtype + )[0] + case sp.Expr(): + return self._typify.typify_expression( + self._freeze(idx), self._ctx.index_dtype + )[0] + case _: + return PsExpression.make(PsConstant(idx, self._ctx.index_dtype)) + + def _parse_any_index(self, idx: Any) -> PsExpression: + return self.parse_index(cast(IndexParsable, idx)) + def parse_slice( self, slic: slice, upper_limit: Any | None = None ) -> tuple[PsExpression, PsExpression, PsExpression]: @@ -75,27 +119,16 @@ class AstFactory: "Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`" ) - def make_expr(val: Any) -> PsExpression: - match val: - case PsExpression(): - return self._typify.typify_expression(val, self._ctx.index_dtype)[0] - case PsSymbol() | PsConstant(): - return self._typify.typify_expression( - PsExpression.make(val), self._ctx.index_dtype - )[0] - case sp.Expr(): - return self._typify.typify_expression( - self._freeze(val), self._ctx.index_dtype - )[0] - case _: - return PsExpression.make(PsConstant(val, self._ctx.index_dtype)) - - start = make_expr(slic.start if slic.start is not None else 0) - stop = make_expr(slic.stop) if slic.stop is not None else make_expr(upper_limit) - step = make_expr(slic.step if slic.step is not None else 1) + start = self._parse_any_index(slic.start if slic.start is not None else 0) + stop = ( + self._parse_any_index(slic.stop) + if slic.stop is not None + else self._parse_any_index(upper_limit) + ) + step = self._parse_any_index(slic.step if slic.step is not None else 1) if isinstance(slic.stop, int) and slic.stop < 0: - stop = make_expr(upper_limit) + stop + stop = self._parse_any_index(upper_limit) + stop return start, stop, step diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index d48953a5b2ccaf1325006f09ec63f5f8fea90f26..263c2f48ecfff15dd4c6271f77ca2a7578b86d09 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,9 +1,10 @@ from __future__ import annotations from typing import Iterable, Iterator -from itertools import chain +from itertools import chain, count from types import EllipsisType -from collections import namedtuple +from collections import namedtuple, defaultdict +import re from ...defaults import DEFAULTS from ...field import Field, FieldType @@ -67,6 +68,9 @@ class KernelCreationContext: self._symbols: dict[str, PsSymbol] = dict() + self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") + self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) + self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -95,6 +99,21 @@ class KernelCreationContext: # Symbols def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: + """Retrieve the symbol with the given name and data type from the symbol table. + + If no symbol named ``name`` exists, a new symbol with the given data type is created. + + If a symbol with the given ``name`` already exists and ``dtype`` is not `None`, + the given data type will be applied to it, and it is returned. + If the symbol already has a different data type, an error will be raised. + + If the symbol already exists and ``dtype`` is `None`, the existing symbol is returned + without checking or altering its data type. + + Args: + name: The symbol's name + dtype: The symbol's data type, or `None` + """ if name not in self._symbols: symb = PsSymbol(name, None) self._symbols[name] = symb @@ -115,12 +134,20 @@ class KernelCreationContext: return self._symbols.get(name, None) def add_symbol(self, symbol: PsSymbol): + """Add an existing symbol to the symbol table. + + If a symbol with the same name already exists, an error will be raised. + """ if symbol.name in self._symbols: raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}") self._symbols[symbol.name] = symbol def replace_symbol(self, old: PsSymbol, new: PsSymbol): + """Replace one symbol by another. + + The two symbols ``old`` and ``new`` must have the same name, but may have different data types. + """ if old.name != new.name: raise PsInternalCompilerError( "replace_symbol: Old and new symbol must have the same name" @@ -131,8 +158,30 @@ class KernelCreationContext: self._symbols[old.name] = new + def duplicate_symbol(self, symb: PsSymbol) -> PsSymbol: + """Canonically duplicates the given symbol. + + A new symbol with the same data type, and new name ``symb.name + "__<counter>"`` is created, + added to the symbol table, and returned. + The ``counter`` reflects the number of previously created duplicates of this symbol. + """ + if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: + span = result.span() + basename = symb.name[: span[0]] + else: + basename = symb.name + + initial_count = self._symbol_dup_table[basename] + for i in count(initial_count): + dup_name = f"{basename}__{i}" + if self.find_symbol(dup_name) is None: + self._symbol_dup_table[basename] = i + 1 + return self.get_symbol(dup_name, symb.dtype) + assert False, "unreachable code" + @property def symbols(self) -> Iterable[PsSymbol]: + """Return an iterable of all symbols listed in the symbol table.""" return self._symbols.values() # Fields and Arrays diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index ba215f822ea7372211bf764425d44e44487cc46b..6adac2a519ffc04505c0e0adac3484d78f30d013 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace): @staticmethod def create_from_slice( ctx: KernelCreationContext, - iteration_slice: Sequence[slice], + iteration_slice: slice | Sequence[slice], archetype_field: Field | None = None, ): """Create an iteration space from a sequence of slices, optionally over an archetype field. @@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace): iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order. """ + if isinstance(iteration_slice, slice): + iteration_slice = (iteration_slice,) + dim = len(iteration_slice) if dim == 0: raise ValueError( diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 01b69509991eaa762a093f50f427f6e4050dc34a..518c402d27e8828cc129c06a20bd10e2ba3d3168 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,16 +1,94 @@ +""" +This module contains various transformation and optimization passes that can be +executed on the backend AST. + +Canonical Form +============== + +Many transformations in this module require that their input AST is in *canonical form*. +This means that: + +- Each symbol, constant, and expression node is annotated with a data type; +- Each symbol has at most one declaration; +- Each symbol that is never written to apart from its declaration has a ``const`` type; and +- Each symbol whose type is *not* ``const`` has at least one non-declaring assignment. + +The first requirement can be ensured by running the `Typifier` on each newly constructed subtree. +The other three requirements are ensured by the `CanonicalizeSymbols` pass, +which should be run first before applying any optimizing transformations. +All transformations in this module retain canonicality of the AST. + +Canonicality allows transformations to forego various checks that would otherwise be necessary +to prove their legality. + +Certain transformations, like the auto-vectorizer (TODO), state additional requirements, e.g. +the absence of loop-carried dependencies. + +Transformations +=============== + +Canonicalization +---------------- + +.. autoclass:: CanonicalizeSymbols + :members: __call__ + +AST Cloning +----------- + +.. autoclass:: CanonicalClone + :members: __call__ + +Simplifying Transformations +--------------------------- + +.. autoclass:: EliminateConstants + :members: __call__ + +.. autoclass:: EliminateBranches + :members: __call__ + +Code Motion +----------- + +.. autoclass:: HoistLoopInvariantDeclarations + :members: __call__ + +Loop Reshaping Transformations +------------------------------ + +.. autoclass:: ReshapeLoops + :members: + + +Code Lowering and Materialization +--------------------------------- + +.. autoclass:: EraseAnonymousStructTypes + :members: __call__ + +.. autoclass:: SelectFunctions + :members: __call__ + +""" + +from .canonicalize_symbols import CanonicalizeSymbols +from .canonical_clone import CanonicalClone from .eliminate_constants import EliminateConstants from .eliminate_branches import EliminateBranches -from .canonicalize_symbols import CanonicalizeSymbols from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations +from .reshape_loops import ReshapeLoops from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ + "CanonicalizeSymbols", + "CanonicalClone", "EliminateConstants", "EliminateBranches", - "CanonicalizeSymbols", "HoistLoopInvariantDeclarations", + "ReshapeLoops", "EraseAnonymousStructTypes", "SelectFunctions", "MaterializeVectorIntrinsics", diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py new file mode 100644 index 0000000000000000000000000000000000000000..538bb2779314fc0fe1d7b83810dd6a4b031ca46a --- /dev/null +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -0,0 +1,112 @@ +from typing import TypeVar, cast + +from ..kernelcreation import KernelCreationContext +from ..symbols import PsSymbol +from ..exceptions import PsInternalCompilerError + +from ..ast import PsAstNode +from ..ast.structural import ( + PsBlock, + PsConditional, + PsLoop, + PsDeclaration, + PsAssignment, + PsComment, +) +from ..ast.expressions import PsExpression, PsSymbolExpr + +__all__ = ["CanonicalClone"] + + +class CloneContext: + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._dup_table: dict[PsSymbol, PsSymbol] = dict() + + def symbol_decl(self, declared_symbol: PsSymbol): + self._dup_table[declared_symbol] = self._ctx.duplicate_symbol(declared_symbol) + + def get_replacement(self, symb: PsSymbol): + return self._dup_table.get(symb, symb) + + +Node_T = TypeVar("Node_T", bound=PsAstNode) + + +class CanonicalClone: + """Clone a subtree, and rename all symbols declared inside it to retain canonicality.""" + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + + def __call__(self, node: Node_T) -> Node_T: + return self.visit(node, CloneContext(self._ctx)) + + def visit(self, node: Node_T, cc: CloneContext) -> Node_T: + match node: + case PsBlock(statements): + return cast( + Node_T, PsBlock([self.visit(stmt, cc) for stmt in statements]) + ) + + case PsLoop(ctr, start, stop, step, body): + cc.symbol_decl(ctr.symbol) + return cast( + Node_T, + PsLoop( + self.visit(ctr, cc), + self.visit(start, cc), + self.visit(stop, cc), + self.visit(step, cc), + self.visit(body, cc), + ), + ) + + case PsConditional(cond, then, els): + return cast( + Node_T, + PsConditional( + self.visit(cond, cc), + self.visit(then, cc), + self.visit(els, cc) if els is not None else None, + ), + ) + + case PsComment(): + return cast(Node_T, node.clone()) + + case PsDeclaration(lhs, rhs): + cc.symbol_decl(node.declared_symbol) + return cast( + Node_T, + PsDeclaration( + cast(PsSymbolExpr, self.visit(lhs, cc)), + self.visit(rhs, cc), + ), + ) + + case PsAssignment(lhs, rhs): + return cast( + Node_T, + PsAssignment( + self.visit(lhs, cc), + self.visit(rhs, cc), + ), + ) + + case PsExpression(): + expr_clone = node.clone() + self._replace_symbols(expr_clone, cc) + return cast(Node_T, expr_clone) + + case _: + raise PsInternalCompilerError( + f"Don't know how to canonically clone {type(node)}" + ) + + def _replace_symbols(self, expr: PsExpression, cc: CloneContext): + if isinstance(expr, PsSymbolExpr): + expr.symbol = cc.get_replacement(expr.symbol) + else: + for c in expr.children: + self._replace_symbols(cast(PsExpression, c), cc) diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index 6fe922f28cfbfd0a42c77dbd3d1606c910abf298..3900105b8f64b7cd33e154de02eaec7cf826d0fb 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -1,5 +1,3 @@ -from itertools import count - from ..kernelcreation import KernelCreationContext from ..symbols import PsSymbol from ..exceptions import PsInternalCompilerError @@ -29,14 +27,10 @@ class CanonContext: self.live_symbols_map[symb] = symb return symb else: - for i in count(): - replacement_name = f"{symb.name}__{i}" - if self._ctx.find_symbol(replacement_name) is None: - replacement = self._ctx.get_symbol(replacement_name, symb.dtype) - self.live_symbols_map[symb] = replacement - self.encountered_symbols.add(replacement) - return replacement - assert False, "unreachable code" + replacement = self._ctx.duplicate_symbol(symb) + self.live_symbols_map[symb] = replacement + self.encountered_symbols.add(replacement) + return replacement def mark_as_updated(self, symb: PsSymbol): self.updated_symbols.add(self.deduplicate(symb)) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 7678dbd8c6ce783585fb7095b201e9f92e65e485..7fa4766eb305954f56d10b8cf8052c2fb26cb8fe 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -1,4 +1,4 @@ -from typing import cast, Iterable +from typing import cast, Iterable, overload from collections import defaultdict from ..kernelcreation import KernelCreationContext, Typifier @@ -116,6 +116,14 @@ class EliminateConstants: self._fold_floats = False self._extract_constant_exprs = extract_constant_exprs + @overload + def __call__(self, node: PsExpression) -> PsExpression: + pass + + @overload + def __call__(self, node: PsAstNode) -> PsAstNode: + pass + def __call__(self, node: PsAstNode) -> PsAstNode: ecc = ECContext(self._ctx) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 5920038150ee6c5295c3f46f4530639ed02fca25..5824239e40ff8365a20658defab344892973f58f 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -69,7 +69,7 @@ class HoistLoopInvariantDeclarations: in particular, each symbol may have at most one declaration. To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`. - `HoistLoopInvariants` assumes that all `PsMathFunction`s are pure (have no side effects), + `HoistLoopInvariantDeclarations` assumes that all `PsMathFunction` s are pure (have no side effects), but makes no such assumption about instances of `CFunction`. """ diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py new file mode 100644 index 0000000000000000000000000000000000000000..6963bee0b2e43bc6bac58a6c96de5f4a35e57148 --- /dev/null +++ b/src/pystencils/backend/transformations/reshape_loops.py @@ -0,0 +1,138 @@ +from typing import Sequence + +from ..kernelcreation import KernelCreationContext, Typifier +from ..kernelcreation.ast_factory import AstFactory, IndexParsable + +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration +from ..ast.expressions import PsExpression, PsConstantExpr, PsLt +from ..constants import PsConstant + +from .canonical_clone import CanonicalClone, CloneContext +from .eliminate_constants import EliminateConstants + + +class ReshapeLoops: + """Various transformations for reshaping loop nests.""" + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._typify = Typifier(ctx) + self._factory = AstFactory(ctx) + self._canon_clone = CanonicalClone(ctx) + self._elim_constants = EliminateConstants(ctx) + + def peel_loop_front( + self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False + ) -> tuple[Sequence[PsBlock], PsLoop]: + """Peel off iterations from the front of a loop. + + Removes ``num_iterations`` from the front of the given loop and returns them as a sequence of + independent blocks. + + Args: + loop: The loop node from which to peel iterations + num_iterations: The number of iterations to peel off + omit_range_check: If set to `True`, assume that the peeled-off iterations will always + be executed, and omit their enclosing conditional. + + Returns: + Tuple containing the peeled-off iterations as a sequence of blocks, + and the remaining loop. + """ + + peeled_iters: list[PsBlock] = [] + + for i in range(num_iterations): + cc = CloneContext(self._ctx) + cc.symbol_decl(loop.counter.symbol) + peeled_ctr = self._factory.parse_index( + cc.get_replacement(loop.counter.symbol) + ) + peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i))) + + counter_decl = PsDeclaration(peeled_ctr, peeled_idx) + peeled_block = self._canon_clone.visit(loop.body, cc) + + if omit_range_check: + peeled_block.statements = [counter_decl] + peeled_block.statements + else: + iter_condition = PsLt(peeled_ctr, loop.stop) + peeled_block.statements = [ + counter_decl, + PsConditional(iter_condition, PsBlock(peeled_block.statements)), + ] + + peeled_iters.append(peeled_block) + + loop.start = self._elim_constants( + self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) + ) + + return peeled_iters, loop + + def cut_loop( + self, loop: PsLoop, cutting_points: Sequence[IndexParsable] + ) -> Sequence[PsLoop | PsBlock]: + """Cut a loop at the given cutting points. + + Cut the given loop at the iterations specified by the given cutting points, + producing ``n`` new subtrees representing the iterations + ``(loop.start:cutting_points[0]), (cutting_points[0]:cutting_points[1]), ..., (cutting_points[-1]:loop.stop)``. + + Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are + returned without the trivial loop structure. + + Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside + the loop's iteration range. + + Returns: + Sequence of ``n`` subtrees representing the respective iteration ranges + """ + + if not ( + isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1 + ): + raise NotImplementedError( + "Loop cutting for loops with step != 1 is not implemented" + ) + + result: list[PsLoop | PsBlock] = [] + new_start = loop.start + cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [ + loop.stop + ] + + for new_end in cutting_points: + if new_end.structurally_equal(new_start): + continue + + num_iters = self._elim_constants(self._typify(new_end - new_start)) + skip = False + + if isinstance(num_iters, PsConstantExpr): + if num_iters.constant.value == 0: + skip = True + elif num_iters.constant.value == 1: + skip = True + cc = CloneContext(self._ctx) + cc.symbol_decl(loop.counter.symbol) + local_counter = self._factory.parse_index( + cc.get_replacement(loop.counter.symbol) + ) + ctr_decl = PsDeclaration( + local_counter, + new_start, + ) + cloned_body = self._canon_clone.visit(loop.body, cc) + cloned_body.statements = [ctr_decl] + cloned_body.statements + result.append(cloned_body) + + if not skip: + loop_clone = self._canon_clone(loop) + loop_clone.start = new_start.clone() + loop_clone.stop = new_end.clone() + result.append(loop_clone) + + new_start = new_end + + return result diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 1dfcfea2a9bfd8e9db70dab7ca61732523855843..7fd6d778ff62f7fb2fcbc24a55af5225fb9f870e 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -52,7 +52,7 @@ def test_invalid_slices(): ctx.add_field(archetype_field) islice = (slice(1, -1, 0.5),) - with pytest.raises(PsTypeError): + with pytest.raises(TypeError): FullIterationSpace.create_from_slice(ctx, islice, archetype_field) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) diff --git a/tests/nbackend/transformations/test_canonical_clone.py b/tests/nbackend/transformations/test_canonical_clone.py new file mode 100644 index 0000000000000000000000000000000000000000..b158b91781b49f8d589a3da3b266e8c2137fceab --- /dev/null +++ b/tests/nbackend/transformations/test_canonical_clone.py @@ -0,0 +1,63 @@ +import sympy as sp +from pystencils import Field, Assignment, make_slice, TypedSymbol +from pystencils.types.quick import Arr + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import CanonicalClone +from pystencils.backend.ast.structural import PsBlock, PsComment +from pystencils.backend.ast.expressions import PsSymbolExpr +from pystencils.backend.ast.iteration import dfs_preorder + + +def test_clone_entire_ast(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + canon_clone = CanonicalClone(ctx) + + f = Field.create_generic("f", 2, index_shape=(5,)) + rho = sp.Symbol("rho") + u = sp.symbols("u_:2") + + cx = TypedSymbol("cx", Arr(ctx.default_dtype)) + cy = TypedSymbol("cy", Arr(ctx.default_dtype)) + cxs = sp.IndexedBase(cx, shape=(5,)) + cys = sp.IndexedBase(cy, shape=(5,)) + + rho_out = Field.create_generic("rho", 2, index_shape=(1,)) + u_out = Field.create_generic("u", 2, index_shape=(2,)) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f) + ctx.set_iteration_space(ispace) + + asms = [ + Assignment(cx, (0, 1, -1, 0, 0)), + Assignment(cy, (0, 0, 0, 1, -1)), + Assignment(rho, sum(f.center(i) for i in range(5))), + Assignment(u[0], 1 / rho * sum((f.center(i) * cxs[i]) for i in range(5))), + Assignment(u[1], 1 / rho * sum((f.center(i) * cys[i]) for i in range(5))), + Assignment(rho_out.center(0), rho), + Assignment(u_out.center(0), u[0]), + Assignment(u_out.center(1), u[1]), + ] + + body = PsBlock( + [PsComment("Compute and export density and velocity")] + + [factory.parse_sympy(asm) for asm in asms] + ) + + ast = factory.loops_from_ispace(ispace, body) + ast_clone = canon_clone(ast) + + for orig, clone in zip(dfs_preorder(ast), dfs_preorder(ast_clone), strict=True): + assert type(orig) is type(clone) + assert orig is not clone + + if isinstance(orig, PsSymbolExpr): + assert isinstance(clone, PsSymbolExpr) + + if orig.symbol.name in ("ctr_0", "ctr_1", "rho", "u_0", "u_1", "cx", "cy"): + assert clone.symbol.name == orig.symbol.name + "__0" diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py new file mode 100644 index 0000000000000000000000000000000000000000..e68cff1b64acbb4f9bbf30dee9ef3f2abe9e59d3 --- /dev/null +++ b/tests/nbackend/transformations/test_reshape_loops.py @@ -0,0 +1,101 @@ +import sympy as sp + +from pystencils import Field, Assignment, make_slice +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import ReshapeLoops + +from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional +from pystencils.backend.ast.expressions import PsConstantExpr, PsLt + + +def test_loop_cutting(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + reshape = ReshapeLoops(ctx) + + x, y, z = sp.symbols("x, y, z") + + f = Field.create_generic("f", 1, index_shape=(2,)) + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) + ctx.set_iteration_space(ispace) + + loop_body = PsBlock( + [ + factory.parse_sympy(Assignment(x, 2 * z)), + factory.parse_sympy(Assignment(f.center(0), x + y)), + ] + ) + + loop = factory.loops_from_ispace(ispace, loop_body) + + subloops = reshape.cut_loop(loop, [1, 1, 3]) + assert len(subloops) == 3 + + subloop = subloops[0] + assert isinstance(subloop, PsBlock) + assert isinstance(subloop.statements[0], PsDeclaration) + assert subloop.statements[0].declared_symbol.name == "ctr_0__0" + + x_decl = subloop.statements[1] + assert isinstance(x_decl, PsDeclaration) + assert x_decl.declared_symbol.name == "x__0" + + subloop = subloops[1] + assert isinstance(subloop, PsLoop) + assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1 + assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3 + + x_decl = subloop.body.statements[0] + assert isinstance(x_decl, PsDeclaration) + assert x_decl.declared_symbol.name == "x__1" + + subloop = subloops[2] + assert isinstance(subloop, PsLoop) + assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3 + assert subloop.stop.structurally_equal(loop.stop) + + +def test_loop_peeling(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + reshape = ReshapeLoops(ctx) + + x, y, z = sp.symbols("x, y, z") + + f = Field.create_generic("f", 1, index_shape=(2,)) + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) + ctx.set_iteration_space(ispace) + + loop_body = PsBlock([ + factory.parse_sympy(Assignment(x, 2 * z)), + factory.parse_sympy(Assignment(f.center(0), x + y)), + ]) + + loop = factory.loops_from_ispace(ispace, loop_body) + + num_iters = 3 + peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) + assert len(peeled_iters) == 3 + + for i, iter in enumerate(peeled_iters): + assert isinstance(iter, PsBlock) + + ctr_decl = iter.statements[0] + assert isinstance(ctr_decl, PsDeclaration) + assert ctr_decl.declared_symbol.name == f"ctr_0__{i}" + + cond = iter.statements[1] + assert isinstance(cond, PsConditional) + assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop)) + + subblock = cond.branch_true + assert isinstance(subblock.statements[0], PsDeclaration) + assert subblock.statements[0].declared_symbol.name == f"x__{i}" + + assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) + assert peeled_loop.stop.structurally_equal(loop.stop) + assert peeled_loop.body.structurally_equal(loop.body)