diff --git a/pystencils/simp/__init__.py b/pystencils/simp/__init__.py index 190fce9622d61656c9d8a3861be1715d12384fe6..0e5ff0e6435bf08f3d305ff7cc9087775a09c6e2 100644 --- a/pystencils/simp/__init__.py +++ b/pystencils/simp/__init__.py @@ -5,10 +5,17 @@ from .simplifications import ( add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments, subexpression_substitution_in_existing_subexpressions, subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list) +from .subexpression_insertion import ( + insert_aliases, insert_zeros, insert_constants, + insert_constant_additions, insert_constant_multiples, + insert_squares, insert_symbol_times_minus_one) from .simplificationstrategy import SimplificationStrategy __all__ = ['AssignmentCollection', 'SimplificationStrategy', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants', - 'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads'] + 'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads', + 'insert_aliases', 'insert_zeros', 'insert_constants', + 'insert_constant_additions', 'insert_constant_multiples', + 'insert_squares', 'insert_symbol_times_minus_one'] diff --git a/pystencils/simp/subexpression_insertion.py b/pystencils/simp/subexpression_insertion.py new file mode 100644 index 0000000000000000000000000000000000000000..9293b56be8f6fb223da6778a33db7e07bf463fc4 --- /dev/null +++ b/pystencils/simp/subexpression_insertion.py @@ -0,0 +1,93 @@ +import sympy as sp +from pystencils.sympyextensions import is_constant + +# Subexpression Insertion + + +def insert_subexpressions(ac, selection_callback, skip=set()): + """ + Removes a number of subexpressions from an assignment collection by + inserting their right-hand side wherever they occur. + + Args: + - selection_callback: Function that is called to qualify subexpressions + for insertion. Should return `True` for any subexpression that is to be + inserted, and `False` otherwise. + - skip: Set of symbols (left-hand sides of subexpressions) that should be + ignored even if qualified by the callback. + """ + i = 0 + while i < len(ac.subexpressions): + exp = ac.subexpressions[i] + if exp.lhs not in skip and selection_callback(exp): + ac = ac.new_with_inserted_subexpression(exp.lhs) + else: + i += 1 + + return ac + + +def insert_aliases(ac, **kwargs): + """Inserts subexpressions that are aliases of other symbols, + i.e. their right-hand side is only another symbol.""" + return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs) + + +def insert_zeros(ac, **kwargs): + """Inserts subexpressions whose right-hand side is zero.""" + zero = sp.Integer(0) + return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs) + + +def insert_constants(ac, **kwargs): + """Inserts subexpressions whose right-hand side is constant, + i.e. contains no symbols.""" + return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs) + + +def insert_symbol_times_minus_one(ac, **kwargs): + """Inserts subexpressions whose right-hand side is just a + negation of another symbol.""" + def callback(exp): + rhs = exp.rhs + minus_one = sp.Integer(-1) + atoms = rhs.atoms(sp.Symbol) + return len(atoms) == 1 and rhs == minus_one * atoms.pop() + return insert_subexpressions(ac, callback, **kwargs) + + +def insert_constant_multiples(ac, **kwargs): + """Inserts subexpressions whose right-hand side is a constant + multiplied with another symbol.""" + def callback(exp): + rhs = exp.rhs + symbols = rhs.atoms(sp.Symbol) + numbers = rhs.atoms(sp.Number) + return len(symbols) == 1 and len(numbers) == 1 and \ + rhs == numbers.pop() * symbols.pop() + return insert_subexpressions(ac, callback, **kwargs) + + +def insert_constant_additions(ac, **kwargs): + """Inserts subexpressions whose right-hand side is a sum of a + constant and another symbol.""" + def callback(exp): + rhs = exp.rhs + symbols = rhs.atoms(sp.Symbol) + numbers = rhs.atoms(sp.Number) + return len(symbols) == 1 and len(numbers) == 1 and \ + rhs == numbers.pop() + symbols.pop() + return insert_subexpressions(ac, callback, **kwargs) + + +def insert_squares(ac, **kwargs): + """Inserts subexpressions whose right-hand side is another symbol squared.""" + def callback(exp): + rhs = exp.rhs + symbols = rhs.atoms(sp.Symbol) + return len(symbols) == 1 and rhs == symbols.pop() ** 2 + return insert_subexpressions(ac, callback, **kwargs) + + +def bind_symbols_to_skip(insertion_function, skip): + return lambda ac: insertion_function(ac, skip=skip) diff --git a/pystencils_tests/test_subexpression_insertion.py b/pystencils_tests/test_subexpression_insertion.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae64d9fe0693016fde25c5aed0fc93942f1d763 --- /dev/null +++ b/pystencils_tests/test_subexpression_insertion.py @@ -0,0 +1,46 @@ +import sympy as sp +from pystencils import fields, Assignment, AssignmentCollection +from pystencils.simp.subexpression_insertion import * + + +def test_subexpression_insertion(): + f, g = fields('f(10), g(10) : [2D]') + xi = sp.symbols('xi_:10') + xi_set = set(xi) + + subexpressions = [ + Assignment(xi[0], -f(4)), + Assignment(xi[1], -(f(1) * f(2))), + Assignment(xi[2], 2.31 * f(5)), + Assignment(xi[3], 1.8 + f(5) + f(6)), + Assignment(xi[4], 5.7 + f(6)), + Assignment(xi[5], (f(4) + f(5))**2), + Assignment(xi[6], f(3)**2), + Assignment(xi[7], f(4)), + Assignment(xi[8], 13), + Assignment(xi[9], 0), + ] + + assignments = [Assignment(g(i), x) for i, x in enumerate(xi)] + ac = AssignmentCollection(assignments, subexpressions=subexpressions) + + ac_ins = insert_symbol_times_minus_one(ac) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[0]}) + + ac_ins = insert_constant_multiples(ac) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[0], xi[2]}) + + ac_ins = insert_constant_additions(ac) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[4]}) + + ac_ins = insert_squares(ac) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[6]}) + + ac_ins = insert_aliases(ac) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[7]}) + + ac_ins = insert_zeros(ac) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[9]}) + + ac_ins = insert_constants(ac, skip={xi[9]}) + assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[8]})