diff --git a/.gitignore b/.gitignore index 3d736e113e2b7aeb636abe9677916725fbc1af89..32a9d13575e401b4e255ca58d0639bf65c7ff449 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__ .cache _build /.idea +.vscode .cache _local_tmp RELEASE-VERSION diff --git a/pystencils/simp/__init__.py b/pystencils/simp/__init__.py index dadaa7911a536ed36eafa75b720d2cddfae6a6d9..190fce9622d61656c9d8a3861be1715d12384fe6 100644 --- a/pystencils/simp/__init__.py +++ b/pystencils/simp/__init__.py @@ -1,5 +1,6 @@ from .assignment_collection import AssignmentCollection from .simplifications import ( + add_subexpressions_for_constants, add_subexpressions_for_divisions, add_subexpressions_for_field_reads, add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments, subexpression_substitution_in_existing_subexpressions, @@ -9,5 +10,5 @@ from .simplificationstrategy import SimplificationStrategy __all__ = ['AssignmentCollection', 'SimplificationStrategy', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', - 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', - 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads'] + 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants', + 'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads'] diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 6bd1c66021c129e35288101c24cdcca96fcebd46..95064408959830650df6402cf8f22983122638c7 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -300,7 +300,7 @@ class AssignmentCollection: new_sub_expr = [eq for eq in self.subexpressions if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] - return AssignmentCollection(new_assignments, new_sub_expr) + return self.copy(new_assignments, new_sub_expr) def new_without_unused_subexpressions(self) -> 'AssignmentCollection': """Returns new collection that only contains subexpressions required to compute the main assignments.""" diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 234b7a373bcc5ad4dc7939f7cdb0044b50bb0c6d..7345ad6f3fff3f79bc54237d0dcc76253831efa8 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -1,12 +1,13 @@ from itertools import chain from typing import Callable, List, Sequence, Union +from collections import defaultdict import sympy as sp from pystencils.assignment import Assignment from pystencils.astnodes import Node from pystencils.field import AbstractField, Field -from pystencils.sympyextensions import subs_additive +from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: @@ -83,6 +84,39 @@ def subexpression_substitution_in_main_assignments(ac): return ac.copy(result) +def add_subexpressions_for_constants(ac): + """Extracts constant factors to subexpressions in the given assignment collection. + + SymPy will exclude common factors from a sum only if they are symbols. This simplification + can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence, + the number of multiplications is reduced and in some cases, more common subexpressions can be found. + """ + constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator)) + + def visit(expr): + args = list(expr.args) + if len(args) == 0: + return expr + if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul): + for i, arg in enumerate(args): + if is_constant(arg) and abs(arg) != 1: + if arg < 0: + args[i] = - constants_to_subexp_dict[- arg] + else: + args[i] = constants_to_subexp_dict[arg] + return expr.func(*(visit(a) for a in args)) + main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments] + subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions] + + symbols_to_collect = set(constants_to_subexp_dict.values()) + + main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments] + subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions] + + subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions + return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions) + + def add_subexpressions_for_divisions(ac): r"""Introduces subexpressions for all divisions which have no constant in the denominator. @@ -172,7 +206,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]): - """Applies sympy expand operation to all equations in collection.""" + """Applies a given operation to all equations in collection.""" def f(ac): return ac.copy(transform_rhs(ac.main_assignments, operation)) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index ce23f4541b46444f55b4d87ee89590bc519586a6..29b524eef506dcabcdac3c487bbf53e87965b42a 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -431,6 +431,28 @@ def extract_most_common_factor(term): return common_factor, term / common_factor +def recursive_collect(expr, symbols, order_by_occurences=False): + """Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1, + and so on. + + Args: + expr: A sympy expression + symbols: A sequence of symbols + order_by_occurences: If True, during recursive descent, always collect the symbol occuring + most often in the expression. + """ + if order_by_occurences: + symbols = list(expr.atoms(sp.Symbol) & set(symbols)) + symbols = sorted(symbols, key=expr.count, reverse=True) + if len(symbols) == 0: + return expr + symbol = symbols[0] + collected_poly = sp.Poly(expr.collect(symbol), symbol) + coeffs = collected_poly.all_coeffs()[::-1] + rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs)) + return rec_sum + + def count_operations(term: Union[sp.Expr, List[sp.Expr]], only_type: Optional[str] = 'real') -> Dict[str, int]: """Counts the number of additions, multiplications and division. diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py index b9f9cc8a16308702df1cf1274c62ca086ef72d3f..62fa262fb203d02b2c10e624893b428f97d33328 100644 --- a/pystencils_tests/test_simplifications.py +++ b/pystencils_tests/test_simplifications.py @@ -1,9 +1,11 @@ +import pytest import sympy as sp from pystencils.simp import subexpression_substitution_in_main_assignments from pystencils.simp import add_subexpressions_for_divisions from pystencils.simp import add_subexpressions_for_sums from pystencils.simp import add_subexpressions_for_field_reads +from pystencils.simp.simplifications import add_subexpressions_for_constants from pystencils import Assignment, AssignmentCollection, fields a, b, c, d, x, y, z = sp.symbols("a b c d x y z") @@ -58,6 +60,45 @@ def test_add_subexpressions_for_divisions(): assert 1/d in rhs +def test_add_subexpressions_for_constants(): + half = sp.Rational(1,2) + sqrt_2 = sp.sqrt(2) + main = [ + Assignment(f[0], half * a + half * b + half * c), + Assignment(f[1], - half * a - half * b), + Assignment(f[2], a * sqrt_2 - b * sqrt_2), + Assignment(f[3], a**2 + b**2) + ] + ac = AssignmentCollection(main) + ac = add_subexpressions_for_constants(ac) + + assert len(ac.subexpressions) == 2 + + half_subexp = None + sqrt_subexp = None + + for asm in ac.subexpressions: + if asm.rhs == half: + half_subexp = asm.lhs + elif asm.rhs == sqrt_2: + sqrt_subexp = asm.lhs + else: + pytest.fail(f"An unexpected subexpression was encountered: {asm}") + + assert half_subexp is not None + assert sqrt_subexp is not None + + for asm in ac.main_assignments[:3]: + assert isinstance(asm.rhs, sp.Mul) + + assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args) + assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args) + assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args) + + # Do not replace exponents! + assert ac.main_assignments[3].rhs == a**2 + b**2 + + def test_add_subexpressions_for_sums(): subexpressions = [ Assignment(s0, a + b + c + d),