From 7c0fc3e88693cff00012e32627813b987d0138be Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Thu, 7 Mar 2019 16:06:44 +0100 Subject: [PATCH] pygrandchem: approximate divs & sqrts --- fast_approximation.py | 32 +++++++++++++++----------------- fd/finitedifferences.py | 2 +- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/fast_approximation.py b/fast_approximation.py index 538e493f9..0dbdab241 100644 --- a/fast_approximation.py +++ b/fast_approximation.py @@ -9,15 +9,28 @@ from pystencils.simp import AssignmentCollection class fast_division(sp.Function): nargs = (2,) + # noinspection PyPep8Naming class fast_sqrt(sp.Function): nargs = (1, ) + # noinspection PyPep8Naming class fast_inv_sqrt(sp.Function): nargs = (1, ) +def _run(term, visitor): + if isinstance(term, AssignmentCollection): + new_main_assignments = _run(term.main_assignments, visitor) + new_subexpressions = _run(term.subexpressions, visitor) + return term.copy(new_main_assignments, new_subexpressions) + elif isinstance(term, list): + return [_run(e, visitor) for e in term] + else: + return visitor(term) + + def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): def visit(expr): if isinstance(expr, Node): @@ -31,15 +44,7 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]) else: new_args = [visit(a) for a in expr.args] return expr.func(*new_args) if new_args else expr - - if isinstance(term, AssignmentCollection): - new_main_assignments = insert_fast_sqrts(term.main_assignments) - new_subexpressions = insert_fast_sqrts(term.subexpressions) - return term.copy(new_main_assignments, new_subexpressions) - elif isinstance(term, list): - return [insert_fast_sqrts(e) for e in term] - else: - return visit(term) + return _run(term, visit) def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): @@ -65,11 +70,4 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti new_args = [visit(a) for a in expr.args] return expr.func(*new_args) if new_args else expr - if isinstance(term, AssignmentCollection): - new_main_assignments = insert_fast_divisions(term.main_assignments) - new_subexpressions = insert_fast_divisions(term.subexpressions) - return term.copy(new_main_assignments, new_subexpressions) - elif isinstance(term, list): - return [insert_fast_divisions(e) for e in term] - else: - return visit(term) + return _run(term, visit) diff --git a/fd/finitedifferences.py b/fd/finitedifferences.py index 3879addc0..fe764361e 100644 --- a/fd/finitedifferences.py +++ b/fd/finitedifferences.py @@ -135,7 +135,7 @@ class Discretization2ndOrder: def __call__(self, expr): if isinstance(expr, list): return [self(e) for e in expr] - elif isinstance(expr, sp.Matrix): + elif isinstance(expr, sp.Matrix) or isinstance(expr, sp.ImmutableDenseMatrix): return expr.applyfunc(self.__call__) elif isinstance(expr, AssignmentCollection): return expr.copy(main_assignments=[e for e in expr.main_assignments], -- GitLab