diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
index 6963bee0b2e43bc6bac58a6c96de5f4a35e57148..b5f7dde01317df412be9c0b34f906b5c2f2f8706 100644
--- a/src/pystencils/backend/transformations/reshape_loops.py
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -4,7 +4,7 @@
 from ..kernelcreation import KernelCreationContext, Typifier
 from ..kernelcreation.ast_factory import AstFactory, IndexParsable
 from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
-from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
+from ..ast.expressions import PsExpression, PsConstantExpr, PsGe, PsLt
 from ..constants import PsConstant
 
 from .canonical_clone import CanonicalClone, CloneContext
@@ -48,7 +48,9 @@ class ReshapeLoops:
             peeled_ctr = self._factory.parse_index(
                 cc.get_replacement(loop.counter.symbol)
             )
-            peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i)))
+            peeled_idx = self._elim_constants(
+                self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step)
+            )
 
             counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
             peeled_block = self._canon_clone.visit(loop.body, cc)
@@ -65,11 +67,71 @@ class ReshapeLoops:
             peeled_iters.append(peeled_block)
 
         loop.start = self._elim_constants(
-            self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
+            self._typify(
+                loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
+            )
         )
 
         return peeled_iters, loop
 
+    def peel_loop_back(
+        self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
+    ) -> tuple[PsLoop, Sequence[PsBlock]]:
+        """Peel off iterations from the back of a loop.
+
+        Removes ``num_iterations`` from the back of the given loop and returns them as a sequence of
+        independent blocks.
+
+        Args:
+            loop: The loop node from which to peel iterations
+            num_iterations: The number of iterations to peel off
+            omit_range_check: If set to `True`, assume that the peeled-off iterations will always
+                be executed, and omit their enclosing conditional.
+
+        Returns:
+            Tuple containing the modified loop and the peeled-off iterations (sequence of blocks).
+        """
+
+        if not (
+            isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
+        ):
+            raise NotImplementedError(
+                "Peeling iterations from the back of loops is only implemented "
+                "for loops with unit step. Implementation is deferred until "
+                "loop range canonicalization is available (also needed for the "
+                "vectorizer)."
+            )
+
+        peeled_iters: list[PsBlock] = []
+
+        for i in range(num_iterations)[::-1]:
+            cc = CloneContext(self._ctx)
+            cc.symbol_decl(loop.counter.symbol)
+            peeled_ctr = self._factory.parse_index(
+                cc.get_replacement(loop.counter.symbol)
+            )
+            peeled_idx = self._typify(loop.stop - PsExpression.make(PsConstant(i + 1)))
+
+            counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
+            peeled_block = self._canon_clone.visit(loop.body, cc)
+
+            if omit_range_check:
+                peeled_block.statements = [counter_decl] + peeled_block.statements
+            else:
+                iter_condition = PsGe(peeled_ctr, loop.start)
+                peeled_block.statements = [
+                    counter_decl,
+                    PsConditional(iter_condition, PsBlock(peeled_block.statements)),
+                ]
+
+            peeled_iters.append(peeled_block)
+
+        loop.stop = self._elim_constants(
+            self._typify(loop.stop - PsExpression.make(PsConstant(num_iterations)))
+        )
+
+        return loop, peeled_iters
+
     def cut_loop(
         self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
     ) -> Sequence[PsLoop | PsBlock]:
diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py
index e68cff1b64acbb4f9bbf30dee9ef3f2abe9e59d3..c52e98ba0fb2269afc4c3b968bcfc599cd400a7d 100644
--- a/tests/nbackend/transformations/test_reshape_loops.py
+++ b/tests/nbackend/transformations/test_reshape_loops.py
@@ -8,8 +8,13 @@ from pystencils.backend.kernelcreation import (
 )
 from pystencils.backend.transformations import ReshapeLoops
 
-from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
-from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
+from pystencils.backend.ast.structural import (
+    PsDeclaration,
+    PsBlock,
+    PsLoop,
+    PsConditional,
+)
+from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
 
 
 def test_loop_cutting():
@@ -43,10 +48,12 @@ def test_loop_cutting():
     x_decl = subloop.statements[1]
     assert isinstance(x_decl, PsDeclaration)
     assert x_decl.declared_symbol.name == "x__0"
-    
+
     subloop = subloops[1]
     assert isinstance(subloop, PsLoop)
-    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
+    assert (
+        isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
+    )
     assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
 
     x_decl = subloop.body.statements[0]
@@ -55,7 +62,9 @@ def test_loop_cutting():
 
     subloop = subloops[2]
     assert isinstance(subloop, PsLoop)
-    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
+    assert (
+        isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
+    )
     assert subloop.stop.structurally_equal(loop.stop)
 
 
@@ -67,19 +76,23 @@ def test_loop_peeling():
     x, y, z = sp.symbols("x, y, z")
 
     f = Field.create_generic("f", 1, index_shape=(2,))
-    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ispace = FullIterationSpace.create_from_slice(
+        ctx, slice(2, 11, 3), archetype_field=f
+    )
     ctx.set_iteration_space(ispace)
 
-    loop_body = PsBlock([
-        factory.parse_sympy(Assignment(x, 2 * z)),
-        factory.parse_sympy(Assignment(f.center(0), x + y)),
-    ])
+    loop_body = PsBlock(
+        [
+            factory.parse_sympy(Assignment(x, 2 * z)),
+            factory.parse_sympy(Assignment(f.center(0), x + y)),
+        ]
+    )
 
     loop = factory.loops_from_ispace(ispace, loop_body)
 
-    num_iters = 3
+    num_iters = 2
     peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
-    assert len(peeled_iters) == 3
+    assert len(peeled_iters) == num_iters
 
     for i, iter in enumerate(peeled_iters):
         assert isinstance(iter, PsBlock)
@@ -87,6 +100,8 @@ def test_loop_peeling():
         ctr_decl = iter.statements[0]
         assert isinstance(ctr_decl, PsDeclaration)
         assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
+        ctr_value = {0: 2, 1: 5}[i]
+        assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value))
 
         cond = iter.statements[1]
         assert isinstance(cond, PsConditional)
@@ -96,6 +111,53 @@ def test_loop_peeling():
         assert isinstance(subblock.statements[0], PsDeclaration)
         assert subblock.statements[0].declared_symbol.name == f"x__{i}"
 
-    assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
+    assert peeled_loop.start.structurally_equal(factory.parse_index(8))
     assert peeled_loop.stop.structurally_equal(loop.stop)
     assert peeled_loop.body.structurally_equal(loop.body)
+
+
+def test_loop_peeling_back():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    reshape = ReshapeLoops(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    f = Field.create_generic("f", 1, index_shape=(2,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    loop_body = PsBlock(
+        [
+            factory.parse_sympy(Assignment(x, 2 * z)),
+            factory.parse_sympy(Assignment(f.center(0), x + y)),
+        ]
+    )
+
+    loop = factory.loops_from_ispace(ispace, loop_body)
+
+    num_iters = 3
+    peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
+    assert len(peeled_iters) == num_iters
+
+    for i, iter in enumerate(peeled_iters):
+        assert isinstance(iter, PsBlock)
+
+        ctr_decl = iter.statements[0]
+        assert isinstance(ctr_decl, PsDeclaration)
+        assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
+
+        cond = iter.statements[1]
+        assert isinstance(cond, PsConditional)
+        assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))
+
+        subblock = cond.branch_true
+        assert isinstance(subblock.statements[0], PsDeclaration)
+        assert subblock.statements[0].declared_symbol.name == f"x__{i}"
+
+    assert peeled_loop.start.structurally_equal(loop.start)
+    assert peeled_loop.stop.structurally_equal(
+        factory.loops_from_ispace(ispace, loop_body).stop
+        - factory.parse_index(num_iters)
+    )
+    assert peeled_loop.body.structurally_equal(loop.body)