Skip to content
Snippets Groups Projects
Commit 7c0fc3e8 authored by Martin Bauer's avatar Martin Bauer
Browse files

pygrandchem: approximate divs & sqrts

parent 0d12a2ac
No related merge requests found
......@@ -9,15 +9,28 @@ from pystencils.simp import AssignmentCollection
class fast_division(sp.Function):
nargs = (2,)
# noinspection PyPep8Naming
class fast_sqrt(sp.Function):
nargs = (1, )
# noinspection PyPep8Naming
class fast_inv_sqrt(sp.Function):
nargs = (1, )
def _run(term, visitor):
if isinstance(term, AssignmentCollection):
new_main_assignments = _run(term.main_assignments, visitor)
new_subexpressions = _run(term.subexpressions, visitor)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [_run(e, visitor) for e in term]
else:
return visitor(term)
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
def visit(expr):
if isinstance(expr, Node):
......@@ -31,15 +44,7 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection])
else:
new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr
if isinstance(term, AssignmentCollection):
new_main_assignments = insert_fast_sqrts(term.main_assignments)
new_subexpressions = insert_fast_sqrts(term.subexpressions)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [insert_fast_sqrts(e) for e in term]
else:
return visit(term)
return _run(term, visit)
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
......@@ -65,11 +70,4 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti
new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr
if isinstance(term, AssignmentCollection):
new_main_assignments = insert_fast_divisions(term.main_assignments)
new_subexpressions = insert_fast_divisions(term.subexpressions)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [insert_fast_divisions(e) for e in term]
else:
return visit(term)
return _run(term, visit)
......@@ -135,7 +135,7 @@ class Discretization2ndOrder:
def __call__(self, expr):
if isinstance(expr, list):
return [self(e) for e in expr]
elif isinstance(expr, sp.Matrix):
elif isinstance(expr, sp.Matrix) or isinstance(expr, sp.ImmutableDenseMatrix):
return expr.applyfunc(self.__call__)
elif isinstance(expr, AssignmentCollection):
return expr.copy(main_assignments=[e for e in expr.main_assignments],
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment