From e46f414906a8938945b709a7a6c4563283a3faba Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Sun, 9 Aug 2020 11:14:23 +0200 Subject: [PATCH] Added test case for inner loop split --- pystencils/transformations.py | 6 +-- .../test_simplification_strategy.py | 37 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/pystencils/transformations.py b/pystencils/transformations.py index b3f9431bb..5e306f2de 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node): return reversed(result) -def get_loop_counter_symbol_hierarchy(astNode): +def get_loop_counter_symbol_hierarchy(ast_node): """Determines the loop counter symbols around a given AST node. - :param astNode: the AST node + :param ast_node: the AST node :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop """ result = [] - node = astNode + node = ast_node while node is not None: node = get_next_parent_of_type(node, ast.LoopOverCoordinate) if node: diff --git a/pystencils_tests/test_simplification_strategy.py b/pystencils_tests/test_simplification_strategy.py index 189482c00..5176ae5f4 100644 --- a/pystencils_tests/test_simplification_strategy.py +++ b/pystencils_tests/test_simplification_strategy.py @@ -1,5 +1,6 @@ import sympy as sp +import pystencils as ps from pystencils import Assignment, AssignmentCollection from pystencils.simp import ( SimplificationStrategy, apply_on_all_subexpressions, @@ -43,3 +44,39 @@ def test_simplification_strategy(): assert 'Adds' in report._repr_html_() assert 'factor' in str(strategy) + + +def test_split_inner_loop(): + dst = ps.fields('dst(8): double[2D]') + s = sp.symbols('s_:8') + x = sp.symbols('x') + subexpressions = [] + main = [ + Assignment(dst[0, 0](0), s[0]), + Assignment(dst[0, 0](1), s[1]), + Assignment(dst[0, 0](2), s[2]), + Assignment(dst[0, 0](3), s[3]), + Assignment(dst[0, 0](4), s[4]), + Assignment(dst[0, 0](5), s[5]), + Assignment(dst[0, 0](6), s[6]), + Assignment(dst[0, 0](7), s[7]), + Assignment(x, sum(s)) + ] + ac = AssignmentCollection(main, subexpressions) + split_groups = [[dst[0, 0](0), dst[0, 0](1)], + [dst[0, 0](2), dst[0, 0](3)], + [dst[0, 0](4), dst[0, 0](5)], + [dst[0, 0](6), dst[0, 0](7), x]] + ac.simplification_hints['split_groups'] = split_groups + ast = ps.create_kernel(ac) + + code = ps.get_code_str(ast) + # we have four inner loops as indicated in split groups (4 elements) plus one outer loop + assert code.count('for') == 5 + + ac = AssignmentCollection(main, subexpressions) + ast = ps.create_kernel(ac) + + code = ps.get_code_str(ast) + # one inner loop and one outer loop + assert code.count('for') == 2 -- GitLab