diff --git a/plot2d.py b/plot2d.py
index 68cc8d729ed44ea90c4099d18cde182f99c152d3..8d7d0189aa7bce9bd522ea4ece5ecc46367e0570 100644
--- a/plot2d.py
+++ b/plot2d.py
@@ -255,6 +255,7 @@ def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_
 
 def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, rescale=True,
                            plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs):
+    """Animation of scalar field as colored image, see `scalar_field`."""
     import matplotlib.animation as animation
 
     fig = gcf()
@@ -274,7 +275,7 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re
             f_min, f_max = np.min(f), np.max(f)
             f = (f - f_min) / (f_max - f_min)
         if hasattr(f, 'mask'):
-            f = np.ma.masked_array(f, mask=f.mask[:, :, 0])
+            f = np.ma.masked_array(f, mask=f.mask[:, :])
         f = np.swapaxes(f, 0, 1)
         im.set_array(f)
         plot_update_function(im)
@@ -283,7 +284,8 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re
     return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
 
 
-def surface_plot_animation(run_function, frames=90, interval=30):
+def surface_plot_animation(run_function, frames=90, interval=30, **kwargs):
+    """Animation of scalar field as 3D plot."""
     from mpl_toolkits.mplot3d import Axes3D
     import matplotlib.animation as animation
     import matplotlib.pyplot as plt
@@ -291,14 +293,19 @@ def surface_plot_animation(run_function, frames=90, interval=30):
 
     fig = plt.figure()
     ax = fig.add_subplot(111, projection='3d')
-    x, y, data = run_function(1)
-    ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
+    data = run_function()
+    x, y = np.meshgrid(np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij')
+    kwargs.setdefault('rstride', 2)
+    kwargs.setdefault('cstride', 2)
+    kwargs.setdefault('color', 'b')
+    kwargs.setdefault('cmap', cm.coolwarm)
+    ax.plot_surface(x, y, data, **kwargs)
     ax.set_zlim(-1.0, 1.0)
 
     def update_figure(*_):
-        x_grid, y_grid, d = run_function(1)
+        d = run_function()
         ax.clear()
-        plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
+        plot = ax.plot_surface(x, y, d, **kwargs)
         ax.set_zlim(-1.0, 1.0)
         return plot,