From bda99ca660a7c7d07e046751136efaaab88be273 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 29 Jan 2021 15:42:39 +0100
Subject: [PATCH] add_subexpressions_for_constants and new_filtered fix

---
 .gitignore                               |  1 +
 pystencils/simp/__init__.py              |  5 +--
 pystencils/simp/assignment_collection.py |  2 +-
 pystencils/simp/simplifications.py       | 38 ++++++++++++++++++++--
 pystencils/sympyextensions.py            | 22 +++++++++++++
 pystencils_tests/test_simplifications.py | 41 ++++++++++++++++++++++++
 6 files changed, 104 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3d736e113..32a9d1357 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@ __pycache__
 .cache
 _build
 /.idea
+.vscode
 .cache
 _local_tmp
 RELEASE-VERSION
diff --git a/pystencils/simp/__init__.py b/pystencils/simp/__init__.py
index dadaa7911..190fce962 100644
--- a/pystencils/simp/__init__.py
+++ b/pystencils/simp/__init__.py
@@ -1,5 +1,6 @@
 from .assignment_collection import AssignmentCollection
 from .simplifications import (
+    add_subexpressions_for_constants,
     add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
     add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
     subexpression_substitution_in_existing_subexpressions,
@@ -9,5 +10,5 @@ from .simplificationstrategy import SimplificationStrategy
 __all__ = ['AssignmentCollection', 'SimplificationStrategy',
            'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
            'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
-           'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
-           'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
+           'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
+           'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 6bd1c6602..950644089 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -300,7 +300,7 @@ class AssignmentCollection:
 
         new_sub_expr = [eq for eq in self.subexpressions
                         if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
-        return AssignmentCollection(new_assignments, new_sub_expr)
+        return self.copy(new_assignments, new_sub_expr)
 
     def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
         """Returns new collection that only contains subexpressions required to compute the main assignments."""
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 234b7a373..7345ad6f3 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -1,12 +1,13 @@
 from itertools import chain
 from typing import Callable, List, Sequence, Union
+from collections import defaultdict
 
 import sympy as sp
 
 from pystencils.assignment import Assignment
 from pystencils.astnodes import Node
 from pystencils.field import AbstractField, Field
-from pystencils.sympyextensions import subs_additive
+from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
 
 
 def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
@@ -83,6 +84,39 @@ def subexpression_substitution_in_main_assignments(ac):
     return ac.copy(result)
 
 
+def add_subexpressions_for_constants(ac):
+    """Extracts constant factors to subexpressions in the given assignment collection.
+
+    SymPy will exclude common factors from a sum only if they are symbols. This simplification
+    can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
+    the number of multiplications is reduced and in some cases, more common subexpressions can be found.
+    """
+    constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
+
+    def visit(expr):
+        args = list(expr.args)
+        if len(args) == 0:
+            return expr
+        if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
+            for i, arg in enumerate(args):
+                if is_constant(arg) and abs(arg) != 1:
+                    if arg < 0:
+                        args[i] = - constants_to_subexp_dict[- arg]
+                    else:
+                        args[i] = constants_to_subexp_dict[arg]
+        return expr.func(*(visit(a) for a in args))
+    main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
+    subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
+
+    symbols_to_collect = set(constants_to_subexp_dict.values())
+
+    main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
+    subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]
+
+    subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
+    return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)
+
+
 def add_subexpressions_for_divisions(ac):
     r"""Introduces subexpressions for all divisions which have no constant in the denominator.
 
@@ -172,7 +206,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
 
 
 def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
-    """Applies sympy expand operation to all equations in collection."""
+    """Applies a given operation to all equations in collection."""
 
     def f(ac):
         return ac.copy(transform_rhs(ac.main_assignments, operation))
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index ce23f4541..29b524eef 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -431,6 +431,28 @@ def extract_most_common_factor(term):
     return common_factor, term / common_factor
 
 
+def recursive_collect(expr, symbols, order_by_occurences=False):
+    """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, 
+    and so on.
+
+    Args:
+        expr: A sympy expression
+        symbols: A sequence of symbols
+        order_by_occurences: If True, during recursive descent, always collect the symbol occuring 
+                             most often in the expression.
+    """
+    if order_by_occurences:
+        symbols = list(expr.atoms(sp.Symbol) & set(symbols))
+        symbols = sorted(symbols, key=expr.count, reverse=True)
+    if len(symbols) == 0:
+        return expr
+    symbol = symbols[0]
+    collected_poly = sp.Poly(expr.collect(symbol), symbol)
+    coeffs = collected_poly.all_coeffs()[::-1]
+    rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
+    return rec_sum
+
+
 def count_operations(term: Union[sp.Expr, List[sp.Expr]],
                      only_type: Optional[str] = 'real') -> Dict[str, int]:
     """Counts the number of additions, multiplications and division.
diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py
index b9f9cc8a1..62fa262fb 100644
--- a/pystencils_tests/test_simplifications.py
+++ b/pystencils_tests/test_simplifications.py
@@ -1,9 +1,11 @@
+import pytest
 import sympy as sp
 
 from pystencils.simp import subexpression_substitution_in_main_assignments
 from pystencils.simp import add_subexpressions_for_divisions
 from pystencils.simp import add_subexpressions_for_sums
 from pystencils.simp import add_subexpressions_for_field_reads
+from pystencils.simp.simplifications import add_subexpressions_for_constants
 from pystencils import Assignment, AssignmentCollection, fields
 
 a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
@@ -58,6 +60,45 @@ def test_add_subexpressions_for_divisions():
     assert 1/d in rhs
 
 
+def test_add_subexpressions_for_constants():
+    half = sp.Rational(1,2)
+    sqrt_2 = sp.sqrt(2)
+    main = [
+        Assignment(f[0], half * a + half * b + half * c),
+        Assignment(f[1], - half * a - half * b),
+        Assignment(f[2], a * sqrt_2 - b * sqrt_2),
+        Assignment(f[3], a**2 + b**2)
+    ]
+    ac = AssignmentCollection(main)
+    ac = add_subexpressions_for_constants(ac)
+    
+    assert len(ac.subexpressions) == 2
+    
+    half_subexp = None
+    sqrt_subexp = None
+
+    for asm in ac.subexpressions:
+        if asm.rhs == half:
+            half_subexp = asm.lhs
+        elif asm.rhs == sqrt_2:
+            sqrt_subexp = asm.lhs
+        else:
+            pytest.fail(f"An unexpected subexpression was encountered: {asm}")
+            
+    assert half_subexp is not None
+    assert sqrt_subexp is not None
+    
+    for asm in ac.main_assignments[:3]:
+        assert isinstance(asm.rhs, sp.Mul)
+
+    assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args)
+    assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args)
+    assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args)
+
+    #   Do not replace exponents!
+    assert ac.main_assignments[3].rhs == a**2 + b**2
+
+
 def test_add_subexpressions_for_sums():
     subexpressions = [
         Assignment(s0, a + b + c + d),
-- 
GitLab