diff --git a/plot2d.py b/plot2d.py index 68cc8d729ed44ea90c4099d18cde182f99c152d3..8d7d0189aa7bce9bd522ea4ece5ecc46367e0570 100644 --- a/plot2d.py +++ b/plot2d.py @@ -255,6 +255,7 @@ def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, rescale=True, plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs): + """Animation of scalar field as colored image, see `scalar_field`.""" import matplotlib.animation as animation fig = gcf() @@ -274,7 +275,7 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re f_min, f_max = np.min(f), np.max(f) f = (f - f_min) / (f_max - f_min) if hasattr(f, 'mask'): - f = np.ma.masked_array(f, mask=f.mask[:, :, 0]) + f = np.ma.masked_array(f, mask=f.mask[:, :]) f = np.swapaxes(f, 0, 1) im.set_array(f) plot_update_function(im) @@ -283,7 +284,8 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames) -def surface_plot_animation(run_function, frames=90, interval=30): +def surface_plot_animation(run_function, frames=90, interval=30, **kwargs): + """Animation of scalar field as 3D plot.""" from mpl_toolkits.mplot3d import Axes3D import matplotlib.animation as animation import matplotlib.pyplot as plt @@ -291,14 +293,19 @@ def surface_plot_animation(run_function, frames=90, interval=30): fig = plt.figure() ax = fig.add_subplot(111, projection='3d') - x, y, data = run_function(1) - ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) + data = run_function() + x, y = np.meshgrid(np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij') + kwargs.setdefault('rstride', 2) + kwargs.setdefault('cstride', 2) + kwargs.setdefault('color', 'b') + kwargs.setdefault('cmap', cm.coolwarm) + ax.plot_surface(x, y, data, **kwargs) ax.set_zlim(-1.0, 1.0) def update_figure(*_): - x_grid, y_grid, d = run_function(1) + d = run_function() ax.clear() - plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) + plot = ax.plot_surface(x, y, d, **kwargs) ax.set_zlim(-1.0, 1.0) return plot,