Show in the docs how to create custom optimizations with sympy.codegen.rewriting.optimize

In sympy.codegen.rewriting there is a function optimize:

def optimize(expr, optimizations):
    """ Apply optimizations to an expression.

    Parameters
    ==========

    expr : expression
    optimizations : iterable of ``Optimization`` instances
        The optimizations will be sorted with respect to ``priority`` (highest first).

    Examples
    ========

    >>> from sympy import log, Symbol
    >>> from sympy.codegen.rewriting import optims_c99, optimize
    >>> x = Symbol('x')
    >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
    log1p(x**2) + log2(x + 3)

    """

    for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
        new_expr = optim(expr)
        if optim.cost_function is None:
            expr = new_expr
        else:
            before, after = map(lambda x: optim.cost_function(x), (expr, new_expr))
            if before > after:
                expr = new_expr
    return expr
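
To make the role of cost_function concrete, here is a minimal sketch. The `rsqrt` function, the query helper and both cost functions are made up for illustration and are not part of SymPy or pystencils. The same rewrite is kept when the cost function reports an improvement and discarded when it does not.

```python
from sympy import Function, Pow, Rational, sqrt, symbols
from sympy.codegen.rewriting import ReplaceOptim, optimize

x, y = symbols('x y')
rsqrt = Function('rsqrt')  # hypothetical target function, illustration only


def is_inverse_sqrt(e):
    # Matches b**(-1/2), which is how SymPy represents 1/sqrt(b).
    return e.is_Pow and e.exp == Rational(-1, 2)


def count_pow(e):
    # Illustrative cost: the number of Pow nodes in the expression tree.
    return e.count(Pow)


# The rewrite removes a Pow node, so the cost drops and the result is kept.
accepted = ReplaceOptim(is_inverse_sqrt, lambda p: rsqrt(p.base),
                        cost_function=count_pow)

# A constant cost never reports an improvement, so the rewrite is discarded.
rejected = ReplaceOptim(is_inverse_sqrt, lambda p: rsqrt(p.base),
                        cost_function=lambda e: 0)

expr = 1/sqrt(x) + y
result_accepted = optimize(expr, [accepted])
result_rejected = optimize(expr, [rejected])
```

Leaving cost_function unset, as create_expand_pow_optimization does, means the rewritten expression is always accepted.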

We should use it in create_kernel to benefit from SymPy's optimizations (currently there are only a few, and some of them duplicate optimizations already in pystencils) and to make it easy for users to apply their own optimizations to the expressions. create_kernel could accept an iterable of Optimization instances, defaulting to the collection of optimizations that pystencils normally uses.
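
A rough sketch of what this could look like, without touching the real create_kernel signature: the helper `optimize_assignments` and the name `DEFAULT_OPTIMIZATIONS` are hypothetical, and `optims_c99` only stands in for whatever default collection pystencils would ship. It uses SymPy's own `Assignment` class rather than the pystencils one.

```python
from sympy import Symbol, log
from sympy.codegen.ast import Assignment
from sympy.codegen.rewriting import optimize, optims_c99

# Hypothetical default; pystencils would substitute its own collection here.
DEFAULT_OPTIMIZATIONS = tuple(optims_c99)


def optimize_assignments(assignments, optimizations=DEFAULT_OPTIMIZATIONS):
    """Hypothetical helper: optimize the right-hand side of each assignment."""
    return [Assignment(a.lhs, optimize(a.rhs, optimizations))
            for a in assignments]


x, y = Symbol('x'), Symbol('y')
kernel_body = [Assignment(y, log(x + 3)/log(2) + log(x**2 + 1))]

# Users could pass their own iterable as the second argument instead.
optimized_body = optimize_assignments(kernel_body)
```

Keeping the optimizations as a plain iterable argument means users can extend or replace the defaults without subclassing anything.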

As the following function from sympy.codegen.rewriting shows, it is easy to implement your own optimizations:

def create_expand_pow_optimization(limit):
    """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.

    The requirements for expansions are that the base needs to be a symbol
    and the exponent needs to be an Integer (and be less than or equal to
    ``limit``).

    Parameters
    ==========

    limit : int
         The highest power which is expanded into multiplication.

    Examples
    ========

    >>> from sympy import Symbol, sin
    >>> from sympy.codegen.rewriting import create_expand_pow_optimization
    >>> x = Symbol('x')
    >>> expand_opt = create_expand_pow_optimization(3)
    >>> expand_opt(x**5 + x**3)
    x**5 + x*x*x
    >>> expand_opt(x**5 + x**3 + sin(x)**3)
    x**5 + sin(x)**3 + x*x*x

    """
    return ReplaceOptim(
        lambda e: e.is_Pow and e.base.is_symbol and e.exp.is_Integer and abs(e.exp) <= limit,
        lambda p: (
            UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else
            1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))
        ))
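
Following the same pattern, a user-defined optimization could be sketched as below. The `fabs` function and the name `abs_to_fabs` are invented for this example; only `ReplaceOptim`, `optimize` and `create_expand_pow_optimization` come from SymPy.

```python
from sympy import Abs, Function, Symbol
from sympy.codegen.rewriting import (
    ReplaceOptim, create_expand_pow_optimization, optimize)

fabs = Function('fabs')  # hypothetical stand-in for a C-level fabs call

# Query: which subexpressions to rewrite. Value: what to rewrite them to.
abs_to_fabs = ReplaceOptim(
    lambda e: isinstance(e, Abs),
    lambda e: fabs(e.args[0]),
)

x = Symbol('x')
expr = Abs(x - 1) + x**3

# The custom rule is passed alongside the built-in power expansion.
optimized = optimize(expr, [abs_to_fabs, create_expand_pow_optimization(3)])
```

Both rules keep the default priority here; passing `priority=...` to `ReplaceOptim` would control the order in which `optimize` applies them.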
Edited by Markus Holzer