diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index d91c1dad7036c8cc2f3e2664e047c5b9a13dbd12..c9d66ae260a9dc89697999b437d960ca06d21c77 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -345,14 +345,6 @@ class Block(Node): assert self._nodes.count(insert_before) == 1 idx = self._nodes.index(insert_before) - # move all assignment (definitions to the top) - if isinstance(new_node, SympyAssignment) and new_node.is_declaration: - while idx > 0: - pn = self._nodes[idx - 1] - if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): - idx -= 1 - else: - break if not if_not_exists or self._nodes[idx] != new_node: self._nodes.insert(idx, new_node) @@ -361,14 +353,6 @@ class Block(Node): assert self._nodes.count(insert_after) == 1 idx = self._nodes.index(insert_after) + 1 - # move all assignment (definitions to the top) - if isinstance(new_node, SympyAssignment) and new_node.is_declaration: - while idx > 0: - pn = self._nodes[idx - 1] - if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional): - idx -= 1 - else: - break if not if_not_exists or not (self._nodes[idx - 1] == new_node or (idx < len(self._nodes) and self._nodes[idx] == new_node)): self._nodes.insert(idx, new_node) diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 02806d62223c8443cc2bb91740afd2c0ef6bda55..e07d871e97d413df0f8a9087a5bfbdc590b8067b 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -4,9 +4,11 @@ import warnings from collections import OrderedDict from copy import deepcopy from types import MappingProxyType +from typing import Set import sympy as sp +import pystencils as ps import pystencils.astnodes as ast from pystencils.assignment import Assignment from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type, @@ -582,21 +584,65 @@ def move_constants_before_loop(ast_node): """ assert isinstance(node.parent, ast.Block) + def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool: + if isinstance(node, (ps.Assignment, ast.SympyAssignment)): + if isinstance(node.lhs, ast.ResolvedFieldAccess): + return node.lhs.typed_symbol.name in symbol_names + else: + return node.lhs.name in symbol_names + elif isinstance(node, ast.Block): + for arg in node.args: + if isinstance(arg, ast.SympyAssignment) and arg.is_declaration: + continue + if modifies_or_declares(arg, symbol_names): + return True + return False + elif isinstance(node, ast.LoopOverCoordinate): + return modifies_or_declares(node.body, symbol_names) + elif isinstance(node, ast.Conditional): + return ( + modifies_or_declares(node.true_block, symbol_names) + or (node.false_block and modifies_or_declares(node.false_block, symbol_names)) + ) + elif isinstance(node, ast.KernelFunction): + return False + else: + defs = {s.name for s in node.symbols_defined} + return bool(symbol_names.intersection(defs)) + + dependencies = {s.name for s in node.undefined_symbols} + last_block = node.parent last_block_child = node element = node.parent prev_element = node + while element: - if isinstance(element, ast.Block): + if isinstance(element, (ast.Conditional, ast.KernelFunction)): + # Never move out of Conditionals or KernelFunctions. + break + + elif isinstance(element, ast.Block): last_block = element last_block_child = prev_element - if isinstance(element, ast.Conditional): - break + if any(modifies_or_declares(sibling, dependencies) for sibling in element.args): + # The node depends on one of the statements in this block. + # Do not move further out. + break + + elif isinstance(element, ast.LoopOverCoordinate): + if element.loop_counter_symbol.name in dependencies: + # The node depends on the loop counter. + # Do not move out of this loop. + break + else: - critical_symbols = set([s.name for s in element.symbols_defined]) - if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols): - break + raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n' + f'The expression {element} of type {type(element)} is not known yet.') + + # No dependencies to symbols defined/modified within the current element. + # We can move the node up one level and in front of the current element. prev_element = element element = element.parent return last_block, last_block_child diff --git a/pystencils_tests/test_move_constant_before_loop.py b/pystencils_tests/test_move_constant_before_loop.py index ea736dd183459a896ccf5d86662386a1a396c85c..0f5300f1b8c14747456efb6f1834a3fc3f3bc46f 100644 --- a/pystencils_tests/test_move_constant_before_loop.py +++ b/pystencils_tests/test_move_constant_before_loop.py @@ -25,9 +25,40 @@ def test_symbol_renaming(): loops = block.atoms(LoopOverCoordinate) assert len(loops) == 2 - assert len(block.args[2].body.args) == 1 + assert len(block.args[1].body.args) == 1 assert len(block.args[3].body.args) == 2 for loop in loops: assert len(loop.parent.args) == 4 # 2 loops + 2 subexpressions - assert loop.parent.args[0].lhs.name != loop.parent.args[1].lhs.name + assert loop.parent.args[0].lhs.name != loop.parent.args[2].lhs.name + + +def test_keep_order_of_accesses(): + f = ps.fields("f: [1D]") + x = TypedSymbol("x", np.float64) + n = 5 + + loop = LoopOverCoordinate(Block([SympyAssignment(x, f[0]), + SympyAssignment(f[1], 2 * x)]), + 0, 0, n) + block = Block([loop]) + + ps.transformations.resolve_field_accesses(block) + new_loops = ps.transformations.cut_loop(loop, [n - 1]) + ps.transformations.move_constants_before_loop(new_loops.args[1]) + + kernel_func = ps.astnodes.KernelFunction( + block, ps.Target.CPU, ps.Backend.C, ps.cpu.cpujit.make_python_function, None + ) + kernel = kernel_func.compile() + + print(ps.show_code(kernel_func)) + + f_arr = np.ones(n + 1) + kernel(f=f_arr) + + print(f_arr) + + assert np.allclose(f_arr, np.array([ + 1, 2, 4, 8, 16, 32 + ]))