Newer
Older
import sympy as sp
from typing import List, Union
from pystencils.astnodes import Node
from pystencils.simp import AssignmentCollection
# noinspection PyPep8Naming
class fast_division(sp.Function):
nargs = (2,)
# noinspection PyPep8Naming
class fast_sqrt(sp.Function):
nargs = (1, )
# noinspection PyPep8Naming
class fast_inv_sqrt(sp.Function):
nargs = (1, )
def _run(term, visitor):
if isinstance(term, AssignmentCollection):
new_main_assignments = _run(term.main_assignments, visitor)
new_subexpressions = _run(term.subexpressions, visitor)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [_run(e, visitor) for e in term]
else:
return visitor(term)
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
def visit(expr):
if isinstance(expr, Node):
return expr
if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
power = expr.exp.p
if power < 0:
return fast_inv_sqrt(expr.args[0]) ** (-power)
else:
return fast_sqrt(expr.args[0]) ** power
else:
new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
def visit(expr):
if isinstance(expr, Node):
return expr
if expr.func == sp.Mul:
div_args = []
other_args = []
for a in expr.args:
if a.func == sp.Pow and a.exp.is_integer and a.exp < 0:
div_args.append(visit(a.base) ** (-a.exp))
else:
other_args.append(visit(a))
if div_args:
return fast_division(sp.Mul(*other_args), sp.Mul(*div_args))
else:
return sp.Mul(*other_args)
elif expr.func == sp.Pow and expr.exp.is_integer and expr.exp < 0:
return fast_division(1, visit(expr.base) ** (-expr.exp))
else:
new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr