From 1c3e9880f2c8eb0ae04271a1f1361339061a50ff Mon Sep 17 00:00:00 2001
From: Nils Kohl <nils.kohl@fau.de>
Date: Fri, 26 Apr 2019 16:13:49 +0200
Subject: [PATCH] Added simplifications & transformations

---
 pystencils/simp/simplifications.py | 28 +++++++++++++++++++++++++++
 pystencils/transformations.py      | 31 +++++++++++++++++++++++++++++-
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index a42601b52..47a9a9f64 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -3,6 +3,7 @@ from typing import Callable, List
 
 from pystencils import Field
 from pystencils.assignment import Assignment
+from pystencils.field import AbstractField
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.sympyextensions import subs_additive
 
@@ -85,6 +86,33 @@ def add_subexpressions_for_divisions(ac: AC) -> AC:
     return ac.new_with_substitutions(substitutions, True)
 
 
+def add_subexpressions_for_sums(ac: AC) -> AC:
+    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""
+    addends = []
+
+    def contains_sum(term):
+        if term.func == sp.add.Add:
+            return True
+        if term.is_Atom:
+            return False
+        return any([contains_sum(a) for a in term.args])
+
+    def search_addends(term):
+        if term.func == sp.add.Add:
+            if all([not contains_sum(a) for a in term.args]):
+                addends.extend(term.args)
+        for a in term.args:
+            search_addends(a)
+
+    for eq in ac.all_assignments:
+        search_addends(eq.rhs)
+
+    addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)]
+    new_symbol_gen = ac.subexpression_symbol_generator
+    substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
+    return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
+
+
 def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC:
     r"""Substitutes field accesses on rhs of assignments with subexpressions
 
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index bc325dcec..4490e3e1a 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -702,7 +702,7 @@ def cut_loop(loop_node, cutting_points):
     """
     if loop_node.step != 1:
         raise NotImplementedError("Can only split loops that have a step of 1")
-    new_loops = []
+    new_loops = ast.Block([])
     new_start = loop_node.start
     cutting_points = list(cutting_points) + [loop_node.stop]
     for new_end in cutting_points:
@@ -1089,6 +1089,35 @@ def get_optimal_loop_ordering(fields):
     return list(layout)
 
 
+def get_loop_hierarchy(ast_node):
+    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.
+
+    Returns:
+        sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop
+    """
+    result = []
+    node = ast_node
+    while node is not None:
+        node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
+        if node:
+            result.append(node.coordinate_to_loop_over)
+    return reversed(result)
+
+
+def get_loop_counter_symbol_hierarchy(astNode):
+    """Determines the loop counter symbols around a given AST node.
+    :param astNode: the AST node
+    :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
+    """
+    result = []
+    node = astNode
+    while node is not None:
+        node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
+        if node:
+            result.append(node.loop_counter_symbol)
+    return result
+
+
 def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
     """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).
 
-- 
GitLab