diff --git a/pystencils/transformations.py b/pystencils/transformations.py index b3f9431bbf3035aaabd25f6eb430c738dddaf3a7..5e306f2de168994575562a92156408f128ab447c 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node): return reversed(result) -def get_loop_counter_symbol_hierarchy(astNode): +def get_loop_counter_symbol_hierarchy(ast_node): """Determines the loop counter symbols around a given AST node. - :param astNode: the AST node + :param ast_node: the AST node :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop """ result = [] - node = astNode + node = ast_node while node is not None: node = get_next_parent_of_type(node, ast.LoopOverCoordinate) if node: diff --git a/pystencils_tests/test_simplification_strategy.py b/pystencils_tests/test_simplification_strategy.py index 189482c006197160e1f49771dc534206f8d0ef9e..5176ae5f49aa45ed952882cb0313d5e5f7754177 100644 --- a/pystencils_tests/test_simplification_strategy.py +++ b/pystencils_tests/test_simplification_strategy.py @@ -1,5 +1,6 @@ import sympy as sp +import pystencils as ps from pystencils import Assignment, AssignmentCollection from pystencils.simp import ( SimplificationStrategy, apply_on_all_subexpressions, @@ -43,3 +44,39 @@ def test_simplification_strategy(): assert 'Adds' in report._repr_html_() assert 'factor' in str(strategy) + + +def test_split_inner_loop(): + dst = ps.fields('dst(8): double[2D]') + s = sp.symbols('s_:8') + x = sp.symbols('x') + subexpressions = [] + main = [ + Assignment(dst[0, 0](0), s[0]), + Assignment(dst[0, 0](1), s[1]), + Assignment(dst[0, 0](2), s[2]), + Assignment(dst[0, 0](3), s[3]), + Assignment(dst[0, 0](4), s[4]), + Assignment(dst[0, 0](5), s[5]), + Assignment(dst[0, 0](6), s[6]), + Assignment(dst[0, 0](7), s[7]), + Assignment(x, sum(s)) + ] + ac = AssignmentCollection(main, subexpressions) + split_groups = [[dst[0, 0](0), dst[0, 0](1)], + [dst[0, 0](2), dst[0, 0](3)], + [dst[0, 0](4), dst[0, 0](5)], + [dst[0, 0](6), dst[0, 0](7), x]] + ac.simplification_hints['split_groups'] = split_groups + ast = ps.create_kernel(ac) + + code = ps.get_code_str(ast) + # we have four inner loops as indicated in split groups (4 elements) plus one outer loop + assert code.count('for') == 5 + + ac = AssignmentCollection(main, subexpressions) + ast = ps.create_kernel(ac) + + code = ps.get_code_str(ast) + # one inner loop and one outer loop + assert code.count('for') == 2