diff --git a/__init__.py b/__init__.py
index 0c574e9129ff8eb1998437e400dce2bac12eba0c..20db8d6b9047045a8e2b7d642eb413e7f71ea9f0 100644
--- a/__init__.py
+++ b/__init__.py
@@ -10,6 +10,7 @@ from .assignment import Assignment
 from .sympyextensions import SymbolCreator
 from .datahandling import create_data_handling
 from .kernel_decorator import kernel
+from . import fd
 
 __all__ = ['Field', 'FieldType', 'fields',
            'TypedSymbol',
@@ -20,5 +21,6 @@ __all__ = ['Field', 'FieldType', 'fields',
            'Assignment',
            'SymbolCreator',
            'create_data_handling',
-           'kernel']
+           'kernel',
+           'fd']
 
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 6884f3d8c5f734c8902ce194a6dd24632dbcc968..557f817eb8117334c1d64be662fd10c1c15fcc07 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -17,8 +17,7 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
 def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                   split_groups=(), iteration_slice=None, ghost_layers=None,
                   skip_independence_check=False) -> KernelFunction:
-    """
-    Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
+    """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
     Loops are created according to the field accesses in the equations.
 
diff --git a/jupytersetup.py b/jupytersetup.py
index f478bf18e48e0cc2ecfd5f2f6563841f5235f038..b4b5cafff184df85661d9b4f9430797d54f2ef73 100644
--- a/jupytersetup.py
+++ b/jupytersetup.py
@@ -5,8 +5,7 @@ from tempfile import NamedTemporaryFile
 import base64
 import sympy as sp
 
-__all__ = ['log_progress', 'make_imshow_animation', 'make_surface_plot_animation',
-           'display_animation', 'set_display_mode']
+__all__ = ['log_progress', 'make_imshow_animation', 'display_animation', 'set_display_mode']
 
 
 def log_progress(sequence, every=None, size=None, name='Items'):
@@ -98,28 +97,6 @@ def make_imshow_animation(grid, grid_update_function, frames=90, **_):
     return animation.FuncAnimation(fig, partial(update_figure, image=grid), frames=frames)
 
 
-def make_surface_plot_animation(run_function, frames=90, interval=30):
-    from mpl_toolkits.mplot3d import Axes3D
-    import matplotlib.animation as animation
-    import matplotlib.pyplot as plt
-    from matplotlib import cm
-
-    fig = plt.figure()
-    ax = fig.add_subplot(111, projection='3d')
-    x, y, data = run_function(1)
-    ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
-    ax.set_zlim(-1.0, 1.0)
-
-    def update_figure(*_):
-        x_grid, y_grid, d = run_function(1)
-        ax.clear()
-        plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
-        ax.set_zlim(-1.0, 1.0)
-        return plot,
-
-    return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False)
-
-
 # -------   Version 1: Embed the animation as HTML5 video --------- ----------------------------------
 
 def display_as_html_video(animation, fps=30, show=True, **_):
diff --git a/plot2d.py b/plot2d.py
index e0935c831186e378413613221f7dad25a227c9da..6157f93141343c64a00fe0e8acd8d4398ad17b20 100644
--- a/plot2d.py
+++ b/plot2d.py
@@ -251,3 +251,25 @@ def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_
         return im,
 
     return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
+
+
+def surface_plot_animation(run_function, frames=90, interval=30):
+    from mpl_toolkits.mplot3d import Axes3D
+    import matplotlib.animation as animation
+    import matplotlib.pyplot as plt
+    from matplotlib import cm
+
+    fig = plt.figure()
+    ax = fig.add_subplot(111, projection='3d')
+    x, y, data = run_function(1)
+    ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
+    ax.set_zlim(-1.0, 1.0)
+
+    def update_figure(*_):
+        x_grid, y_grid, d = run_function(1)
+        ax.clear()
+        plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
+        ax.set_zlim(-1.0, 1.0)
+        return plot,
+
+    return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False)