Skip to content
Snippets Groups Projects
Commit 8c744109 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfix in pystencils plotting functions and sphinx doc for plot

parent ca31d3cc
No related merge requests found
......@@ -255,6 +255,7 @@ def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_
def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, rescale=True,
plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs):
"""Animation of scalar field as colored image, see `scalar_field`."""
import matplotlib.animation as animation
fig = gcf()
......@@ -274,7 +275,7 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re
f_min, f_max = np.min(f), np.max(f)
f = (f - f_min) / (f_max - f_min)
if hasattr(f, 'mask'):
f = np.ma.masked_array(f, mask=f.mask[:, :, 0])
f = np.ma.masked_array(f, mask=f.mask[:, :])
f = np.swapaxes(f, 0, 1)
im.set_array(f)
plot_update_function(im)
......@@ -283,7 +284,8 @@ def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, re
return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)
def surface_plot_animation(run_function, frames=90, interval=30):
def surface_plot_animation(run_function, frames=90, interval=30, **kwargs):
"""Animation of scalar field as 3D plot."""
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.pyplot as plt
......@@ -291,14 +293,19 @@ def surface_plot_animation(run_function, frames=90, interval=30):
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x, y, data = run_function(1)
ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
data = run_function()
x, y = np.meshgrid(np.arange(data.shape[0]), np.arange(data.shape[1]), indexing='ij')
kwargs.setdefault('rstride', 2)
kwargs.setdefault('cstride', 2)
kwargs.setdefault('color', 'b')
kwargs.setdefault('cmap', cm.coolwarm)
ax.plot_surface(x, y, data, **kwargs)
ax.set_zlim(-1.0, 1.0)
def update_figure(*_):
x_grid, y_grid, d = run_function(1)
d = run_function()
ax.clear()
plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,)
plot = ax.plot_surface(x, y, d, **kwargs)
ax.set_zlim(-1.0, 1.0)
return plot,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment