From b2f328c017853bef92d9f9dc5ad901b7889a1a63 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Mon, 14 Aug 2023 13:22:13 +0200 Subject: [PATCH] insert nodes precisely before/after the before-/after- argument --- pystencils/astnodes.py | 16 ---------------- .../test_move_constant_before_loop.py | 4 ++-- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index d91c1dad7..c9d66ae26 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -345,14 +345,6 @@ class Block(Node): assert self._nodes.count(insert_before) == 1 idx = self._nodes.index(insert_before) - # move all assignment (definitions to the top) - if isinstance(new_node, SympyAssignment) and new_node.is_declaration: - while idx > 0: - pn = self._nodes[idx - 1] - if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): - idx -= 1 - else: - break if not if_not_exists or self._nodes[idx] != new_node: self._nodes.insert(idx, new_node) @@ -361,14 +353,6 @@ class Block(Node): assert self._nodes.count(insert_after) == 1 idx = self._nodes.index(insert_after) + 1 - # move all assignment (definitions to the top) - if isinstance(new_node, SympyAssignment) and new_node.is_declaration: - while idx > 0: - pn = self._nodes[idx - 1] - if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): - idx -= 1 - else: - break if not if_not_exists or not (self._nodes[idx - 1] == new_node or (idx < len(self._nodes) and self._nodes[idx] == new_node)): self._nodes.insert(idx, new_node) diff --git a/pystencils_tests/test_move_constant_before_loop.py b/pystencils_tests/test_move_constant_before_loop.py index ea736dd18..fb9e537b2 100644 --- a/pystencils_tests/test_move_constant_before_loop.py +++ b/pystencils_tests/test_move_constant_before_loop.py @@ -25,9 +25,9 @@ def test_symbol_renaming(): loops = block.atoms(LoopOverCoordinate) assert len(loops) == 2 - assert len(block.args[2].body.args) == 1 + assert len(block.args[1].body.args) == 1 assert len(block.args[3].body.args) == 2 for loop in loops: assert len(loop.parent.args) == 4 # 2 loops + 2 subexpressions - assert loop.parent.args[0].lhs.name != loop.parent.args[1].lhs.name + assert loop.parent.args[0].lhs.name != loop.parent.args[2].lhs.name -- GitLab