diff --git a/transformations.py b/transformations.py index 2347fb7e0b37fc18a0a7824ae11c3598214b98c2..81dd22c2938feb611aba44d64bf3360c7d2e566b 100644 --- a/transformations.py +++ b/transformations.py @@ -455,11 +455,9 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), def move_constants_before_loop(ast_node): - """ - Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent. + """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent. + Call this after creating the loop structure with :func:`make_loop_over_domain` - :param ast_node: - :return: """ def find_block_to_move_to(node): """ @@ -468,7 +466,6 @@ def move_constants_before_loop(ast_node): :param node: SympyAssignment inside a Block :return blockToInsertTo, childOfBlockToInsertBefore """ - assert isinstance(node, ast.SympyAssignment) assert isinstance(node.parent, ast.Block) last_block = node.parent @@ -510,18 +507,18 @@ def move_constants_before_loop(ast_node): for block in all_blocks: children = block.take_child_nodes() for child in children: - if not isinstance(child, ast.SympyAssignment): - block.append(child) + target, child_to_insert_before = find_block_to_move_to(child) + if target == block: # movement not possible + target.append(child) else: - target, child_to_insert_before = find_block_to_move_to(child) - if target == block: # movement not possible - target.append(child) + if isinstance(child, ast.SympyAssignment): + exists_already = check_if_assignment_already_in_block(child, target) else: - existing_assignment = check_if_assignment_already_in_block(child, target) - if not existing_assignment: - target.insert_before(child, child_to_insert_before) - else: - assert existing_assignment.rhs == child.rhs, "Symbol with same name exists already" + exists_already = False + if not exists_already: + target.insert_before(child, child_to_insert_before) + else: + assert exists_already.rhs == child.rhs, "Symbol with same name exists already" def split_inner_loop(ast_node: ast.Node, symbol_groups):