Skip to content
Snippets Groups Projects
Commit 7c0fc3e8 authored by Martin Bauer's avatar Martin Bauer
Browse files

pygrandchem: approximate divs & sqrts

parent 0d12a2ac
No related merge requests found
...@@ -9,15 +9,28 @@ from pystencils.simp import AssignmentCollection ...@@ -9,15 +9,28 @@ from pystencils.simp import AssignmentCollection
class fast_division(sp.Function): class fast_division(sp.Function):
nargs = (2,) nargs = (2,)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class fast_sqrt(sp.Function): class fast_sqrt(sp.Function):
nargs = (1, ) nargs = (1, )
# noinspection PyPep8Naming # noinspection PyPep8Naming
class fast_inv_sqrt(sp.Function): class fast_inv_sqrt(sp.Function):
nargs = (1, ) nargs = (1, )
def _run(term, visitor):
if isinstance(term, AssignmentCollection):
new_main_assignments = _run(term.main_assignments, visitor)
new_subexpressions = _run(term.subexpressions, visitor)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [_run(e, visitor) for e in term]
else:
return visitor(term)
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
def visit(expr): def visit(expr):
if isinstance(expr, Node): if isinstance(expr, Node):
...@@ -31,15 +44,7 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]) ...@@ -31,15 +44,7 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection])
else: else:
new_args = [visit(a) for a in expr.args] new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr return expr.func(*new_args) if new_args else expr
return _run(term, visit)
if isinstance(term, AssignmentCollection):
new_main_assignments = insert_fast_sqrts(term.main_assignments)
new_subexpressions = insert_fast_sqrts(term.subexpressions)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [insert_fast_sqrts(e) for e in term]
else:
return visit(term)
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]): def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
...@@ -65,11 +70,4 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti ...@@ -65,11 +70,4 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti
new_args = [visit(a) for a in expr.args] new_args = [visit(a) for a in expr.args]
return expr.func(*new_args) if new_args else expr return expr.func(*new_args) if new_args else expr
if isinstance(term, AssignmentCollection): return _run(term, visit)
new_main_assignments = insert_fast_divisions(term.main_assignments)
new_subexpressions = insert_fast_divisions(term.subexpressions)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [insert_fast_divisions(e) for e in term]
else:
return visit(term)
...@@ -135,7 +135,7 @@ class Discretization2ndOrder: ...@@ -135,7 +135,7 @@ class Discretization2ndOrder:
def __call__(self, expr): def __call__(self, expr):
if isinstance(expr, list): if isinstance(expr, list):
return [self(e) for e in expr] return [self(e) for e in expr]
elif isinstance(expr, sp.Matrix): elif isinstance(expr, sp.Matrix) or isinstance(expr, sp.ImmutableDenseMatrix):
return expr.applyfunc(self.__call__) return expr.applyfunc(self.__call__)
elif isinstance(expr, AssignmentCollection): elif isinstance(expr, AssignmentCollection):
return expr.copy(main_assignments=[e for e in expr.main_assignments], return expr.copy(main_assignments=[e for e in expr.main_assignments],
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment