From 7c0fc3e88693cff00012e32627813b987d0138be Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 7 Mar 2019 16:06:44 +0100
Subject: [PATCH] pygrandchem: approximate divs & sqrts

---
 fast_approximation.py   | 32 +++++++++++++++-----------------
 fd/finitedifferences.py |  2 +-
 2 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/fast_approximation.py b/fast_approximation.py
index 538e493f9..0dbdab241 100644
--- a/fast_approximation.py
+++ b/fast_approximation.py
@@ -9,15 +9,28 @@ from pystencils.simp import AssignmentCollection
 class fast_division(sp.Function):
     nargs = (2,)
 
+
 # noinspection PyPep8Naming
 class fast_sqrt(sp.Function):
     nargs = (1, )
 
+
 # noinspection PyPep8Naming
 class fast_inv_sqrt(sp.Function):
     nargs = (1, )
 
 
+def _run(term, visitor):
+    if isinstance(term, AssignmentCollection):
+        new_main_assignments = _run(term.main_assignments, visitor)
+        new_subexpressions = _run(term.subexpressions, visitor)
+        return term.copy(new_main_assignments, new_subexpressions)
+    elif isinstance(term, list):
+        return [_run(e, visitor) for e in term]
+    else:
+        return visitor(term)
+
+
 def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
     def visit(expr):
         if isinstance(expr, Node):
@@ -31,15 +44,7 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection])
         else:
             new_args = [visit(a) for a in expr.args]
             return expr.func(*new_args) if new_args else expr
-
-    if isinstance(term, AssignmentCollection):
-        new_main_assignments = insert_fast_sqrts(term.main_assignments)
-        new_subexpressions = insert_fast_sqrts(term.subexpressions)
-        return term.copy(new_main_assignments, new_subexpressions)
-    elif isinstance(term, list):
-        return [insert_fast_sqrts(e) for e in term]
-    else:
-        return visit(term)
+    return _run(term, visit)
 
 
 def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
@@ -65,11 +70,4 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti
             new_args = [visit(a) for a in expr.args]
             return expr.func(*new_args) if new_args else expr
 
-    if isinstance(term, AssignmentCollection):
-        new_main_assignments = insert_fast_divisions(term.main_assignments)
-        new_subexpressions = insert_fast_divisions(term.subexpressions)
-        return term.copy(new_main_assignments, new_subexpressions)
-    elif isinstance(term, list):
-        return [insert_fast_divisions(e) for e in term]
-    else:
-        return visit(term)
+    return _run(term, visit)
diff --git a/fd/finitedifferences.py b/fd/finitedifferences.py
index 3879addc0..fe764361e 100644
--- a/fd/finitedifferences.py
+++ b/fd/finitedifferences.py
@@ -135,7 +135,7 @@ class Discretization2ndOrder:
     def __call__(self, expr):
         if isinstance(expr, list):
             return [self(e) for e in expr]
-        elif isinstance(expr, sp.Matrix):
+        elif isinstance(expr, sp.Matrix) or isinstance(expr, sp.ImmutableDenseMatrix):
             return expr.applyfunc(self.__call__)
         elif isinstance(expr, AssignmentCollection):
             return expr.copy(main_assignments=[e for e in expr.main_assignments],
-- 
GitLab