From e46f414906a8938945b709a7a6c4563283a3faba Mon Sep 17 00:00:00 2001
From: markus holzer <markus.holzer@fau.de>
Date: Sun, 9 Aug 2020 11:14:23 +0200
Subject: [PATCH] Added test case for inner loop split

---
 pystencils/transformations.py                 |  6 +--
 .../test_simplification_strategy.py           | 37 +++++++++++++++++++
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index b3f9431bb..5e306f2de 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node):
     return reversed(result)
 
 
-def get_loop_counter_symbol_hierarchy(astNode):
+def get_loop_counter_symbol_hierarchy(ast_node):
     """Determines the loop counter symbols around a given AST node.
-    :param astNode: the AST node
+    :param ast_node: the AST node
     :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
     """
     result = []
-    node = astNode
+    node = ast_node
     while node is not None:
         node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
         if node:
diff --git a/pystencils_tests/test_simplification_strategy.py b/pystencils_tests/test_simplification_strategy.py
index 189482c00..5176ae5f4 100644
--- a/pystencils_tests/test_simplification_strategy.py
+++ b/pystencils_tests/test_simplification_strategy.py
@@ -1,5 +1,6 @@
 import sympy as sp
 
+import pystencils as ps
 from pystencils import Assignment, AssignmentCollection
 from pystencils.simp import (
     SimplificationStrategy, apply_on_all_subexpressions,
@@ -43,3 +44,39 @@ def test_simplification_strategy():
     assert 'Adds' in report._repr_html_()
 
     assert 'factor' in str(strategy)
+
+
+def test_split_inner_loop():
+    dst = ps.fields('dst(8): double[2D]')
+    s = sp.symbols('s_:8')
+    x = sp.symbols('x')
+    subexpressions = []
+    main = [
+        Assignment(dst[0, 0](0), s[0]),
+        Assignment(dst[0, 0](1), s[1]),
+        Assignment(dst[0, 0](2), s[2]),
+        Assignment(dst[0, 0](3), s[3]),
+        Assignment(dst[0, 0](4), s[4]),
+        Assignment(dst[0, 0](5), s[5]),
+        Assignment(dst[0, 0](6), s[6]),
+        Assignment(dst[0, 0](7), s[7]),
+        Assignment(x, sum(s))
+    ]
+    ac = AssignmentCollection(main, subexpressions)
+    split_groups = [[dst[0, 0](0), dst[0, 0](1)],
+                    [dst[0, 0](2), dst[0, 0](3)],
+                    [dst[0, 0](4), dst[0, 0](5)],
+                    [dst[0, 0](6), dst[0, 0](7), x]]
+    ac.simplification_hints['split_groups'] = split_groups
+    ast = ps.create_kernel(ac)
+
+    code = ps.get_code_str(ast)
+    # we have four inner loops as indicated in split groups (4 elements) plus one outer loop
+    assert code.count('for') == 5
+
+    ac = AssignmentCollection(main, subexpressions)
+    ast = ps.create_kernel(ac)
+
+    code = ps.get_code_str(ast)
+    # one inner loop and one outer loop
+    assert code.count('for') == 2
-- 
GitLab