diff --git a/fast_approximation.py b/fast_approximation.py index 538e493f9bebcafa1eb788050758a4c93b556332..0dbdab241151a5e211efa0c9a4847fcbd9211fae 100644 --- a/fast_approximation.py +++ b/fast_approximation.py @@ -9,15 +9,28 @@ from pystencils.simp import AssignmentCollection class fast_division(sp.Function): nargs = (2,) + # noinspection PyPep8Naming class fast_sqrt(sp.Function): nargs = (1, ) + # noinspection PyPep8Naming class fast_inv_sqrt(sp.Function): nargs = (1, ) +def _run(term, visitor): + if isinstance(term, AssignmentCollection): + new_main_assignments = _run(term.main_assignments, visitor) + new_subexpressions = _run(term.subexpressions, visitor) + return term.copy(new_main_assignments, new_subexpressions) + elif isinstance(term, list): + return [_run(e, visitor) for e in term] + else: + return visitor(term) + + def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): def visit(expr): if isinstance(expr, Node): @@ -31,15 +44,7 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]) else: new_args = [visit(a) for a in expr.args] return expr.func(*new_args) if new_args else expr - - if isinstance(term, AssignmentCollection): - new_main_assignments = insert_fast_sqrts(term.main_assignments) - new_subexpressions = insert_fast_sqrts(term.subexpressions) - return term.copy(new_main_assignments, new_subexpressions) - elif isinstance(term, list): - return [insert_fast_sqrts(e) for e in term] - else: - return visit(term) + return _run(term, visit) def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): @@ -65,11 +70,4 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti new_args = [visit(a) for a in expr.args] return expr.func(*new_args) if new_args else expr - if isinstance(term, AssignmentCollection): - new_main_assignments = insert_fast_divisions(term.main_assignments) - new_subexpressions = insert_fast_divisions(term.subexpressions) - return term.copy(new_main_assignments, new_subexpressions) - elif isinstance(term, list): - return [insert_fast_divisions(e) for e in term] - else: - return visit(term) + return _run(term, visit) diff --git a/fd/finitedifferences.py b/fd/finitedifferences.py index 3879addc0a82952c836bf18c1a258ad93e0fa67d..fe764361e9bad3b59b11ea4055b8b27e3662fb99 100644 --- a/fd/finitedifferences.py +++ b/fd/finitedifferences.py @@ -135,7 +135,7 @@ class Discretization2ndOrder: def __call__(self, expr): if isinstance(expr, list): return [self(e) for e in expr] - elif isinstance(expr, sp.Matrix): + elif isinstance(expr, sp.Matrix) or isinstance(expr, sp.ImmutableDenseMatrix): return expr.applyfunc(self.__call__) elif isinstance(expr, AssignmentCollection): return expr.copy(main_assignments=[e for e in expr.main_assignments],