In sympy.codegen.rewriting there is a function:
    def optimize(expr, optimizations):
        """ Apply optimizations to an expression.

        Parameters
        ==========

        expr : expression
        optimizations : iterable of ``Optimization`` instances
            The optimizations will be sorted with respect to ``priority`` (highest first).

        Examples
        ========

        >>> from sympy import log, Symbol
        >>> from sympy.codegen.rewriting import optims_c99, optimize
        >>> x = Symbol('x')
        >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
        log1p(x**2) + log2(x + 3)

        """
        for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
            new_expr = optim(expr)
            if optim.cost_function is None:
                expr = new_expr
            else:
                before, after = map(lambda x: optim.cost_function(x), (expr, new_expr))
                if before > after:
                    expr = new_expr
        return expr
We should use it in create_kernel to benefit from SymPy's optimizations (currently very few, and some of them duplicate optimizations already in pystencils) and to make it easy for users to incorporate their own optimizations into the expressions.
create_kernel could accept an iterable of Optimization instances, with a default collection of the optimizations that pystencils normally uses.
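A minimal sketch of what such an interface could look like, using only SymPy. The helper `optimize_assignments` and the constant `DEFAULT_OPTIMIZATIONS` are hypothetical names for illustration; they are not part of pystencils, and `optims_c99` merely stands in for whatever default collection pystencils would ship.

```python
from sympy import Symbol, log
from sympy.codegen.rewriting import optimize, optims_c99

# Hypothetical default collection; stands in for pystencils' own defaults.
DEFAULT_OPTIMIZATIONS = tuple(optims_c99)


def optimize_assignments(assignments, optimizations=DEFAULT_OPTIMIZATIONS):
    """Hypothetical helper: run ``optimize`` on the right-hand side of
    every (lhs, rhs) pair and return the rewritten pairs."""
    return [(lhs, optimize(rhs, optimizations)) for lhs, rhs in assignments]


x, y = Symbol('x'), Symbol('y')
assignments = [(y, log(x + 3)/log(2) + log(x**2 + 1))]

# With the default collection, the logs are rewritten to C99 functions
# (the docstring above shows log1p(x**2) + log2(x + 3) for this input).
optimized = optimize_assignments(assignments)
print(optimized)

# Passing an empty collection leaves the expressions untouched.
untouched = optimize_assignments(assignments, optimizations=())
print(untouched)
```

Users who want different behaviour would pass their own iterable instead of relying on the default, which keeps the built-in optimizations and user-supplied ones on the same code path.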
As you can see, it's really easy to implement your own optimizations:
    def create_expand_pow_optimization(limit):
        """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.

        The requirements for expansions are that the base needs to be a symbol
        and the exponent needs to be an Integer (and be less than or equal to
        ``limit``).

        Parameters
        ==========

        limit : int
            The highest power which is expanded into multiplication.

        Examples
        ========

        >>> from sympy import Symbol, sin
        >>> from sympy.codegen.rewriting import create_expand_pow_optimization
        >>> x = Symbol('x')
        >>> expand_opt = create_expand_pow_optimization(3)
        >>> expand_opt(x**5 + x**3)
        x**5 + x*x*x
        >>> expand_opt(x**5 + x**3 + sin(x)**3)
        x**5 + sin(x)**3 + x*x*x

        """
        return ReplaceOptim(
            lambda e: e.is_Pow and e.base.is_symbol and e.exp.is_Integer and abs(e.exp) <= limit,
            lambda p: (
                UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else
                1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))
            ))
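Following the same pattern, a user-defined optimization might look like the sketch below. The cube-root rewrite and the name `cbrt_opt` are made up for illustration and are not taken from SymPy's or pystencils' built-in collections; only `ReplaceOptim`, `optimize`, `create_expand_pow_optimization` and `Cbrt` are existing SymPy names.

```python
from sympy import Rational, Symbol
from sympy.codegen.cfunctions import Cbrt
from sympy.codegen.rewriting import (
    ReplaceOptim, create_expand_pow_optimization, optimize)

# Hypothetical user optimization: rewrite e**(1/3) as Cbrt(e) so that a
# C code printer can emit cbrt() instead of pow().
cbrt_opt = ReplaceOptim(
    lambda e: e.is_Pow and e.exp == Rational(1, 3),  # query: which nodes match
    lambda p: Cbrt(p.base),                          # value: what replaces them
)

x = Symbol('x')
expr = x**Rational(1, 3) + x**2

# A ReplaceOptim instance is callable on its own ...
only_cbrt = cbrt_opt(expr)
print(only_cbrt)

# ... and can be mixed with existing optimizations in one collection.
combined = optimize(expr, [cbrt_opt, create_expand_pow_optimization(2)])
print(combined)
```

Because both optimizations here keep the default priority and have no cost function, `optimize` simply applies them one after the other in the order given.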