diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 10e855df2fd9a661f053adbdb43dbfb8d969fdf4..b081f752945133d56a4d32407e335b26d13614ec 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -310,6 +310,7 @@ class Block(Node):
 
     def insert_before(self, new_node, insert_before):
         new_node.parent = self
+        assert self._nodes.count(insert_before) == 1
         idx = self._nodes.index(insert_before)
 
         # move all assignment (definitions to the top)
@@ -337,6 +338,7 @@ class Block(Node):
         return tmp
 
     def replace(self, child, replacements):
+        assert self._nodes.count(child) == 1
         idx = self._nodes.index(child)
         del self._nodes[idx]
         if type(replacements) is list:
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 57a97bcb830af86cd702c412a816917bd0131e37..b968d4c621faca9012d0a5997e877c36b2ae54d8 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -287,7 +287,9 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
             last_conditional = Conditional(condition(direction), Block(sp_assignments))
             final_assignments.append(last_conditional)
 
-    prepend_optimizations = [remove_conditionals_in_staggered_kernel, move_constants_before_loop]
+    remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers])
+    prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional),
+                             move_constants_before_loop]
     ast = create_kernel(final_assignments, ghost_layers=ghost_layers, cpu_prepend_optimizations=prepend_optimizations,
                         **kwargs)
     return ast
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index ea340e930631a159cf81cf9d56337b315e27c88b..762c36136cd7f3eb541a6075f4b24021813c82ad 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -630,13 +630,18 @@ def move_constants_before_loop(ast_node):
                     else:
                         target.insert_before(child, child_to_insert_before)
                 elif exists_already and exists_already.rhs == child.rhs:
-                    pass
+                    if target.args.index(exists_already) > target.args.index(child_to_insert_before):
+                        assert target.args.count(exists_already) == 1
+                        assert target.args.count(child_to_insert_before) == 1
+                        target.args.remove(exists_already)
+                        target.insert_before(exists_already, child_to_insert_before)
                 else:
                     # this variable already exists in outer block, but with different rhs
                     # -> symbol has to be renamed
                     assert isinstance(child.lhs, TypedSymbol)
                     new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
-                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
+                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
+                                         child_to_insert_before)
                     substitute_variables[child.lhs] = new_symbol
 
 
@@ -1064,15 +1069,19 @@ def insert_casts(node):
         return node.func(*args)
 
 
-def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
-    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
+def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
+    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
+       first and last element"""
 
     all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
     assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
     inner_loop = all_inner_loops.pop()
 
     for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
-        cut_loop(loop, [loop.stop - 1])
+        if include_first:
+            cut_loop(loop, [loop.start + 1, loop.stop - 1])
+        else:
+            cut_loop(loop, [loop.stop - 1])
 
     simplify_conditionals(function_node.body, loop_counter_simplification=True)
     cleanup_blocks(function_node.body)
diff --git a/pystencils_tests/test_staggered_kernel.py b/pystencils_tests/test_staggered_kernel.py
index a914b91806ee53ab200101edfccfcecaae81207b..4db538bf8f83d6779cb7dfd612fa9928acc23965 100644
--- a/pystencils_tests/test_staggered_kernel.py
+++ b/pystencils_tests/test_staggered_kernel.py
@@ -75,3 +75,11 @@ def test_staggered_subexpressions():
     assignments = [ps.Assignment(j.staggered_access("W"), c),
                    ps.Assignment(c, 1)]
     ps.create_staggered_kernel(assignments, target=dh.default_target).compile()
+
+
+def test_staggered_loop_cutting():
+    dh = ps.create_data_handling((4, 4), periodicity=True, default_target='cpu')
+    j = dh.add_array('j', values_per_cell=4, field_type=ps.FieldType.STAGGERED)
+    assignments = [ps.Assignment(j.staggered_access("SW"), 1)]
+    ast = ps.create_staggered_kernel(assignments, target=dh.default_target)
+    assert not ast.atoms(ps.astnodes.Conditional)