From 3e7154dc776eedadc52925a868e464e23676c044 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 24 Apr 2018 13:26:25 +0200
Subject: [PATCH] pystencils: finite difference & session cleanup

---
 fd/__init__.py          |  3 ++-
 fd/finitedifferences.py | 44 +++++++++++++++++++++++++++++------------
 session.py              |  9 ++++++---
 3 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/fd/__init__.py b/fd/__init__.py
index 2fcf3d9cf..39f7094d4 100644
--- a/fd/__init__.py
+++ b/fd/__init__.py
@@ -7,4 +7,5 @@ from .finitedifferences import advection, diffusion, transient, Discretization2n
 
 __all__ = ['Diff', 'DiffOperator', 'diff_terms', 'collect_diffs', 'create_nested_diff', 'replace_diff', 'zero_diffs',
            'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
-           'expand_diff_products', 'combine_diff_products', 'functional_derivative']
+           'expand_diff_products', 'combine_diff_products', 'functional_derivative',
+           'advection', 'diffusion', 'transient', 'Discretization2ndOrder']
diff --git a/fd/finitedifferences.py b/fd/finitedifferences.py
index b6f05f30e..f39d4c0c7 100644
--- a/fd/finitedifferences.py
+++ b/fd/finitedifferences.py
@@ -1,5 +1,6 @@
 import numpy as np
 import sympy as sp
+from typing import Union, Optional
 
 from pystencils.assignment_collection import AssignmentCollection
 from pystencils.field import Field
@@ -7,24 +8,22 @@ from pystencils.sympyextensions import fast_subs
 from pystencils.fd.derivative import Diff
 
 
-# --------------------------------------- Advection Diffusion ----------------------------------------------------------
+FieldOrFieldAccess = Union[Field, Field.Access]
 
-def advection(advected_scalar, velocity_field, idx=None):
-    """Advection term: divergence( velocity_field * advected_scalar )"""
-    if isinstance(advected_scalar, Field):
-        first_arg = advected_scalar.center
-    elif isinstance(advected_scalar, Field.Access):
-        first_arg = advected_scalar
-    else:
-        raise ValueError("Advected scalar has to be a pystencils Field or Field.Access")
 
-    args = [first_arg, velocity_field if not isinstance(velocity_field, Field) else velocity_field.center]
-    if idx is not None:
-        args.append(idx)
-    return Advection(*args)
+# --------------------------------------- Advection Diffusion ----------------------------------------------------------
 
 
 def diffusion(scalar, diffusion_coeff, idx=None):
+    """Diffusion term ∇·( diffusion_coeff · ∇(scalar))
+
+    Examples:
+        >>> f = Field.create_generic('f', spatial_dimensions=2)
+        >>> diffusion_term = diffusion(scalar=f, diffusion_coeff=sp.Symbol("d"))
+        >>> discretization = Discretization2ndOrder()
+        >>> discretization(diffusion_term)
+        (-4*f_C*d + f_E*d + f_N*d + f_S*d + f_W*d)/dx**2
+    """
     if isinstance(scalar, Field):
         first_arg = scalar.center
     elif isinstance(scalar, Field.Access):
@@ -38,7 +37,26 @@ def diffusion(scalar, diffusion_coeff, idx=None):
     return Diffusion(*args)
 
 
+def advection(advected_scalar: FieldOrFieldAccess, velocity_field: FieldOrFieldAccess, idx: Optional[int] = None):
+    """Advection term  ∇·(velocity_field · advected_scalar)
+
+    Term that describes the advection of a scalar quantity in a velocity field.
+    """
+    if isinstance(advected_scalar, Field):
+        first_arg = advected_scalar.center
+    elif isinstance(advected_scalar, Field.Access):
+        first_arg = advected_scalar
+    else:
+        raise ValueError("Advected scalar has to be a pystencils Field or Field.Access")
+
+    args = [first_arg, velocity_field if not isinstance(velocity_field, Field) else velocity_field.center]
+    if idx is not None:
+        args.append(idx)
+    return Advection(*args)
+
+
 def transient(scalar, idx=None):
+    """Transient term ∂_t(scalar)"""
     if isinstance(scalar, Field):
         args = [scalar.center]
     elif isinstance(scalar, Field.Access):
diff --git a/session.py b/session.py
index fe13926d8..c55c8e498 100644
--- a/session.py
+++ b/session.py
@@ -1,6 +1,9 @@
-from pystencils.sympy_gmpy_bug_workaround import *
-from pystencils import *
+import pystencils.sympy_gmpy_bug_workaround
 import sympy as sp
 import numpy as np
+import pystencils as ps
 import pystencils.plot2d as plt
-from pystencils.jupytersetup import *
+import pystencils.jupytersetup as ps_notebook
+
+
+__all__ = ['sp', 'np', 'ps', 'plt', 'ps_notebook']
-- 
GitLab