Commit b2312d53 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'simplification_patch' into 'master'

add_subexpressions_for_constants and new_filtered fix

See merge request pycodegen/pystencils!207
parents f93e8ff9 bda99ca6
......@@ -9,6 +9,7 @@ __pycache__
.cache
_build
/.idea
.vscode
.cache
_local_tmp
RELEASE-VERSION
......
from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
......@@ -9,5 +10,5 @@ from .simplificationstrategy import SimplificationStrategy
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
......@@ -300,7 +300,7 @@ class AssignmentCollection:
new_sub_expr = [eq for eq in self.subexpressions
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return AssignmentCollection(new_assignments, new_sub_expr)
return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments."""
......
from itertools import chain
from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import AbstractField, Field
from pystencils.sympyextensions import subs_additive
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
......@@ -83,6 +84,39 @@ def subexpression_substitution_in_main_assignments(ac):
return ac.copy(result)
def add_subexpressions_for_constants(ac):
"""Extracts constant factors to subexpressions in the given assignment collection.
SymPy will exclude common factors from a sum only if they are symbols. This simplification
can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
the number of multiplications is reduced and in some cases, more common subexpressions can be found.
"""
constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
def visit(expr):
args = list(expr.args)
if len(args) == 0:
return expr
if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
for i, arg in enumerate(args):
if is_constant(arg) and abs(arg) != 1:
if arg < 0:
args[i] = - constants_to_subexp_dict[- arg]
else:
args[i] = constants_to_subexp_dict[arg]
return expr.func(*(visit(a) for a in args))
main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
symbols_to_collect = set(constants_to_subexp_dict.values())
main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]
subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)
def add_subexpressions_for_divisions(ac):
r"""Introduces subexpressions for all divisions which have no constant in the denominator.
......@@ -172,7 +206,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies sympy expand operation to all equations in collection."""
"""Applies a given operation to all equations in collection."""
def f(ac):
return ac.copy(transform_rhs(ac.main_assignments, operation))
......
......@@ -431,6 +431,28 @@ def extract_most_common_factor(term):
return common_factor, term / common_factor
def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
Args:
expr: A sympy expression
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
"""
if order_by_occurences:
symbols = list(expr.atoms(sp.Symbol) & set(symbols))
symbols = sorted(symbols, key=expr.count, reverse=True)
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected_poly = sp.Poly(expr.collect(symbol), symbol)
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
def count_operations(term: Union[sp.Expr, List[sp.Expr]],
only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division.
......
import pytest
import sympy as sp
from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads
from pystencils.simp.simplifications import add_subexpressions_for_constants
from pystencils import Assignment, AssignmentCollection, fields
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
......@@ -58,6 +60,45 @@ def test_add_subexpressions_for_divisions():
assert 1/d in rhs
def test_add_subexpressions_for_constants():
half = sp.Rational(1,2)
sqrt_2 = sp.sqrt(2)
main = [
Assignment(f[0], half * a + half * b + half * c),
Assignment(f[1], - half * a - half * b),
Assignment(f[2], a * sqrt_2 - b * sqrt_2),
Assignment(f[3], a**2 + b**2)
]
ac = AssignmentCollection(main)
ac = add_subexpressions_for_constants(ac)
assert len(ac.subexpressions) == 2
half_subexp = None
sqrt_subexp = None
for asm in ac.subexpressions:
if asm.rhs == half:
half_subexp = asm.lhs
elif asm.rhs == sqrt_2:
sqrt_subexp = asm.lhs
else:
pytest.fail(f"An unexpected subexpression was encountered: {asm}")
assert half_subexp is not None
assert sqrt_subexp is not None
for asm in ac.main_assignments[:3]:
assert isinstance(asm.rhs, sp.Mul)
assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args)
assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args)
assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args)
# Do not replace exponents!
assert ac.main_assignments[3].rhs == a**2 + b**2
def test_add_subexpressions_for_sums():
subexpressions = [
Assignment(s0, a + b + c + d),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment