From 1c3e9880f2c8eb0ae04271a1f1361339061a50ff Mon Sep 17 00:00:00 2001 From: Nils Kohl <nils.kohl@fau.de> Date: Fri, 26 Apr 2019 16:13:49 +0200 Subject: [PATCH] Added simplifications & transformations --- pystencils/simp/simplifications.py | 28 +++++++++++++++++++++++++++ pystencils/transformations.py | 31 +++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index a42601b52..47a9a9f64 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -3,6 +3,7 @@ from typing import Callable, List from pystencils import Field from pystencils.assignment import Assignment +from pystencils.field import AbstractField from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.sympyextensions import subs_additive @@ -85,6 +86,33 @@ def add_subexpressions_for_divisions(ac: AC) -> AC: return ac.new_with_substitutions(substitutions, True) +def add_subexpressions_for_sums(ac: AC) -> AC: + r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.""" + addends = [] + + def contains_sum(term): + if term.func == sp.add.Add: + return True + if term.is_Atom: + return False + return any([contains_sum(a) for a in term.args]) + + def search_addends(term): + if term.func == sp.add.Add: + if all([not contains_sum(a) for a in term.args]): + addends.extend(term.args) + for a in term.args: + search_addends(a) + + for eq in ac.all_assignments: + search_addends(eq.rhs) + + addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)] + new_symbol_gen = ac.subexpression_symbol_generator + substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)} + return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) + + def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC: r"""Substitutes field accesses on rhs of assignments with subexpressions diff --git a/pystencils/transformations.py b/pystencils/transformations.py index bc325dcec..4490e3e1a 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -702,7 +702,7 @@ def cut_loop(loop_node, cutting_points): """ if loop_node.step != 1: raise NotImplementedError("Can only split loops that have a step of 1") - new_loops = [] + new_loops = ast.Block([]) new_start = loop_node.start cutting_points = list(cutting_points) + [loop_node.stop] for new_end in cutting_points: @@ -1089,6 +1089,35 @@ def get_optimal_loop_ordering(fields): return list(layout) +def get_loop_hierarchy(ast_node): + """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops. + + Returns: + sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop + """ + result = [] + node = ast_node + while node is not None: + node = get_next_parent_of_type(node, ast.LoopOverCoordinate) + if node: + result.append(node.coordinate_to_loop_over) + return reversed(result) + + +def get_loop_counter_symbol_hierarchy(astNode): + """Determines the loop counter symbols around a given AST node. + :param astNode: the AST node + :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop + """ + result = [] + node = astNode + while node is not None: + node = get_next_parent_of_type(node, ast.LoopOverCoordinate) + if node: + result.append(node.loop_counter_symbol) + return result + + def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering). -- GitLab