From 64fcb5d5d6bd268a6f81cf0704af95bf912d81fb Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Tue, 28 May 2024 10:04:09 +0200
Subject: [PATCH] Implement loop peeling from back

---
 .../backend/transformations/reshape_loops.py  | 68 +++++++++++++-
 .../transformations/test_reshape_loops.py     | 88 ++++++++++++++++---
 2 files changed, 140 insertions(+), 16 deletions(-)

diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
index 6963bee0b..b5f7dde01 100644
--- a/src/pystencils/backend/transformations/reshape_loops.py
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -4,7 +4,7 @@ from ..kernelcreation import KernelCreationContext, Typifier
 from ..kernelcreation.ast_factory import AstFactory, IndexParsable
 
 from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
-from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
+from ..ast.expressions import PsExpression, PsConstantExpr, PsGe, PsLt
 from ..constants import PsConstant
 
 from .canonical_clone import CanonicalClone, CloneContext
@@ -48,7 +48,9 @@ class ReshapeLoops:
             peeled_ctr = self._factory.parse_index(
                 cc.get_replacement(loop.counter.symbol)
             )
-            peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i)))
+            peeled_idx = self._elim_constants(
+                self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step)
+            )
 
             counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
             peeled_block = self._canon_clone.visit(loop.body, cc)
@@ -65,11 +67,71 @@ class ReshapeLoops:
             peeled_iters.append(peeled_block)
 
         loop.start = self._elim_constants(
-            self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
+            self._typify(
+                loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
+            )
         )
 
         return peeled_iters, loop
 
+    def peel_loop_back(
+        self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
+    ) -> tuple[PsLoop, Sequence[PsBlock]]:
+        """Peel off iterations from the back of a loop.
+
+        Removes ``num_iterations`` from the back of the given loop and returns them as a sequence of
+        independent blocks.
+
+        Args:
+            loop: The loop node from which to peel iterations
+            num_iterations: The number of iterations to peel off
+            omit_range_check: If set to `True`, assume that the peeled-off iterations will always
+              be executed, and omit their enclosing conditional.
+
+        Returns:
+            Tuple containing the modified loop and the peeled-off iterations (sequence of blocks).
+
+        Raises:
+            NotImplementedError: If the step of ``loop`` is not the constant 1.
+        """
+        if not (
+            isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
+        ):
+            raise NotImplementedError(
+                "Peeling iterations from the back of loops is only implemented "
+                "for loops with unit step. Implementation is deferred until "
+                "loop range canonicalization is available (also needed for the "
+                "vectorizer)."
+            )
+
+        peeled_iters: list[PsBlock] = []
+
+        for i in reversed(range(num_iterations)):
+            cc = CloneContext(self._ctx)
+            cc.symbol_decl(loop.counter.symbol)
+            peeled_ctr = self._factory.parse_index(
+                cc.get_replacement(loop.counter.symbol)
+            )
+            peeled_idx = self._typify(loop.stop - PsExpression.make(PsConstant(i + 1)))
+
+            counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
+            peeled_block = self._canon_clone.visit(loop.body, cc)
+
+            if omit_range_check:
+                peeled_block.statements = [counter_decl] + peeled_block.statements
+            else:
+                in_range = PsGe(peeled_ctr, loop.start)
+                guarded_body = PsConditional(in_range, PsBlock(peeled_block.statements))
+                peeled_block.statements = [counter_decl, guarded_body]
+
+            peeled_iters.append(peeled_block)
+
+        loop.stop = self._elim_constants(
+            self._typify(loop.stop - PsExpression.make(PsConstant(num_iterations)))
+        )
+
+        return loop, peeled_iters
+
     def cut_loop(
         self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
     ) -> Sequence[PsLoop | PsBlock]:
diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py
index e68cff1b6..c52e98ba0 100644
--- a/tests/nbackend/transformations/test_reshape_loops.py
+++ b/tests/nbackend/transformations/test_reshape_loops.py
@@ -8,8 +8,13 @@ from pystencils.backend.kernelcreation import (
 )
 from pystencils.backend.transformations import ReshapeLoops
 
-from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
-from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
+from pystencils.backend.ast.structural import (
+    PsDeclaration,
+    PsBlock,
+    PsLoop,
+    PsConditional,
+)
+from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
 
 
 def test_loop_cutting():
@@ -43,10 +48,12 @@ def test_loop_cutting():
     x_decl = subloop.statements[1]
     assert isinstance(x_decl, PsDeclaration)
     assert x_decl.declared_symbol.name == "x__0"
-    
+
     subloop = subloops[1]
     assert isinstance(subloop, PsLoop)
-    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
+    assert (
+        isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
+    )
     assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
 
     x_decl = subloop.body.statements[0]
@@ -55,7 +62,9 @@ def test_loop_cutting():
 
     subloop = subloops[2]
     assert isinstance(subloop, PsLoop)
-    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
+    assert (
+        isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
+    )
     assert subloop.stop.structurally_equal(loop.stop)
 
 
@@ -67,19 +76,23 @@ def test_loop_peeling():
     x, y, z = sp.symbols("x, y, z")
 
     f = Field.create_generic("f", 1, index_shape=(2,))
-    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ispace = FullIterationSpace.create_from_slice(
+        ctx, slice(2, 11, 3), archetype_field=f
+    )
     ctx.set_iteration_space(ispace)
 
-    loop_body = PsBlock([
-        factory.parse_sympy(Assignment(x, 2 * z)),
-        factory.parse_sympy(Assignment(f.center(0), x + y)),
-    ])
+    loop_body = PsBlock(
+        [
+            factory.parse_sympy(Assignment(x, 2 * z)),
+            factory.parse_sympy(Assignment(f.center(0), x + y)),
+        ]
+    )
 
     loop = factory.loops_from_ispace(ispace, loop_body)
 
-    num_iters = 3
+    num_iters = 2
     peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
-    assert len(peeled_iters) == 3
+    assert len(peeled_iters) == num_iters
 
     for i, iter in enumerate(peeled_iters):
         assert isinstance(iter, PsBlock)
@@ -87,6 +100,8 @@ def test_loop_peeling():
         ctr_decl = iter.statements[0]
         assert isinstance(ctr_decl, PsDeclaration)
         assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
+        ctr_value = {0: 2, 1: 5}[i]
+        assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value))
 
         cond = iter.statements[1]
         assert isinstance(cond, PsConditional)
@@ -96,6 +111,53 @@ def test_loop_peeling():
         assert isinstance(subblock.statements[0], PsDeclaration)
         assert subblock.statements[0].declared_symbol.name == f"x__{i}"
 
-    assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
+    assert peeled_loop.start.structurally_equal(factory.parse_index(8))
     assert peeled_loop.stop.structurally_equal(loop.stop)
     assert peeled_loop.body.structurally_equal(loop.body)
+
+
+def test_loop_peeling_back():
+    """Check the blocks and loop bounds produced by ``ReshapeLoops.peel_loop_back``."""
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    reshape = ReshapeLoops(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    f = Field.create_generic("f", 1, index_shape=(2,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    loop_body = PsBlock(
+        [
+            factory.parse_sympy(Assignment(x, 2 * z)),
+            factory.parse_sympy(Assignment(f.center(0), x + y)),
+        ]
+    )
+
+    loop = factory.loops_from_ispace(ispace, loop_body)
+    num_iters = 3
+    peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
+    assert len(peeled_iters) == num_iters
+
+    for i, peeled_iter in enumerate(peeled_iters):
+        assert isinstance(peeled_iter, PsBlock)
+
+        ctr_decl = peeled_iter.statements[0]
+        assert isinstance(ctr_decl, PsDeclaration)
+        assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
+
+        cond = peeled_iter.statements[1]
+        assert isinstance(cond, PsConditional)
+        assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))
+
+        subblock = cond.branch_true
+        assert isinstance(subblock.statements[0], PsDeclaration)
+        assert subblock.statements[0].declared_symbol.name == f"x__{i}"
+
+    assert peeled_loop.start.structurally_equal(loop.start)
+    assert peeled_loop.stop.structurally_equal(
+        factory.loops_from_ispace(ispace, loop_body).stop
+        - factory.parse_index(num_iters)
+    )
+    assert peeled_loop.body.structurally_equal(loop.body)
-- 
GitLab