Skip to content
Snippets Groups Projects
fast_approximation.py 2.28 KiB
Newer Older
import sympy as sp
from typing import List, Union

from pystencils.astnodes import Node
from pystencils.simp import AssignmentCollection


# noinspection PyPep8Naming
class fast_division(sp.Function):
    """Symbolic marker for a division that may use a fast approximation.

    Called as ``fast_division(numerator, denominator)``; ``insert_fast_divisions``
    in this module produces it. How it is lowered is decided elsewhere (not in
    this file).
    """
    nargs = (2,)

# noinspection PyPep8Naming
class fast_sqrt(sp.Function):
    """Symbolic marker for a square root that may use a fast approximation.

    ``fast_sqrt(x)`` stands for ``x ** (1/2)``; ``insert_fast_sqrts`` in this
    module produces it.
    """
    nargs = (1,)

# noinspection PyPep8Naming
class fast_inv_sqrt(sp.Function):
    """Symbolic marker for an inverse square root that may use a fast approximation.

    ``fast_inv_sqrt(x)`` stands for ``x ** (-1/2)``; ``insert_fast_sqrts`` in
    this module produces it.
    """
    nargs = (1,)


def _run(term, visitor):
    """Apply ``visitor`` to every expression contained in ``term``.

    * An ``AssignmentCollection`` has its main assignments and its
      subexpressions transformed (in that order). The result is built with
      ``term.copy(...)``.
    * A list is transformed element-wise; nested lists are handled recursively.
    * Anything else is handed directly to ``visitor``.
    """
    if isinstance(term, AssignmentCollection):
        transformed_main = _run(term.main_assignments, visitor)
        transformed_subs = _run(term.subexpressions, visitor)
        return term.copy(transformed_main, transformed_subs)
    if isinstance(term, list):
        return [_run(item, visitor) for item in term]
    return visitor(term)


def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace half-integer powers with ``fast_sqrt`` / ``fast_inv_sqrt`` calls.

    A power ``b ** (p/2)`` (exponent is a sympy Rational with denominator 2)
    is rewritten as:

    * ``fast_sqrt(b) ** p`` when ``p > 0``;
    * ``fast_inv_sqrt(b) ** (-p)`` when ``p < 0``.

    The base ``b`` is transformed as well, so nested roots are replaced too.
    All other expressions are rebuilt from their transformed arguments.
    pystencils AST ``Node`` objects are returned unchanged.

    Args:
        term: a sympy expression, a (possibly nested) list of expressions, or
            an ``AssignmentCollection``.

    Returns:
        An object of the same kind as ``term`` with the replacements applied.
    """
    def visit(expr):
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
            # Visit the base first so roots nested inside it (e.g. the inner
            # root of sqrt(1 + sqrt(x))) are replaced as well.
            base = visit(expr.base)
            power = expr.exp.p
            if power < 0:
                return fast_inv_sqrt(base) ** (-power)
            return fast_sqrt(base) ** power
        new_args = [visit(a) for a in expr.args]
        # Atoms (symbols, numbers) have no args and are returned as-is.
        return expr.func(*new_args) if new_args else expr
    return _run(term, visit)


def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace divisions with ``fast_division(numerator, denominator)`` calls.

    A reciprocal factor is a power ``b ** e`` whose exponent ``e`` sympy knows
    to be a negative integer.

    * In a product, all reciprocal factors are gathered into the denominator
      as ``b ** (-e)``. The remaining factors form the numerator, and the
      result is ``fast_division(numerator, denominator)``. A product without
      reciprocal factors is rebuilt from its transformed factors.
    * A bare reciprocal power becomes ``fast_division(1, b ** (-e))``.

    Bases and remaining factors are transformed recursively. Exponents whose
    sign sympy cannot decide are left untouched. pystencils AST ``Node``
    objects are returned unchanged.

    Args:
        term: a sympy expression, a (possibly nested) list of expressions, or
            an ``AssignmentCollection``.

    Returns:
        An object of the same kind as ``term`` with the replacements applied.
    """

    def is_reciprocal_power(expr):
        # ``is_negative`` is a three-valued assumption query (True/False/None).
        # Comparing ``expr.exp < 0`` instead would produce an undecidable
        # Relational for e.g. an integer symbol of unknown sign, and using that
        # in ``if`` raises TypeError.
        return expr.func == sp.Pow and expr.exp.is_integer and expr.exp.is_negative

    def visit(expr):
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Mul:
            div_args = []
            other_args = []
            for a in expr.args:
                if is_reciprocal_power(a):
                    div_args.append(visit(a.base) ** (-a.exp))
                else:
                    other_args.append(visit(a))
            if div_args:
                # An empty numerator list yields sp.Mul() == 1.
                return fast_division(sp.Mul(*other_args), sp.Mul(*div_args))
            return sp.Mul(*other_args)
        if is_reciprocal_power(expr):
            return fast_division(1, visit(expr.base) ** (-expr.exp))
        new_args = [visit(a) for a in expr.args]
        # Atoms (symbols, numbers) have no args and are returned as-is.
        return expr.func(*new_args) if new_args else expr

    return _run(term, visit)