diff --git a/__init__.py b/__init__.py index 0c574e9129ff8eb1998437e400dce2bac12eba0c..20db8d6b9047045a8e2b7d642eb413e7f71ea9f0 100644 --- a/__init__.py +++ b/__init__.py @@ -10,6 +10,7 @@ from .assignment import Assignment from .sympyextensions import SymbolCreator from .datahandling import create_data_handling from .kernel_decorator import kernel +from . import fd __all__ = ['Field', 'FieldType', 'fields', 'TypedSymbol', @@ -20,5 +21,6 @@ __all__ = ['Field', 'FieldType', 'fields', 'Assignment', 'SymbolCreator', 'create_data_handling', - 'kernel'] + 'kernel', + 'fd'] diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 6884f3d8c5f734c8902ce194a6dd24632dbcc968..557f817eb8117334c1d64be662fd10c1c15fcc07 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -17,8 +17,7 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]] def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double', split_groups=(), iteration_slice=None, ghost_layers=None, skip_independence_check=False) -> KernelFunction: - """ - Creates an abstract syntax tree for a kernel function, by taking a list of update rules. + """Creates an abstract syntax tree for a kernel function, by taking a list of update rules. Loops are created according to the field accesses in the equations. diff --git a/jupytersetup.py b/jupytersetup.py index f478bf18e48e0cc2ecfd5f2f6563841f5235f038..b4b5cafff184df85661d9b4f9430797d54f2ef73 100644 --- a/jupytersetup.py +++ b/jupytersetup.py @@ -5,8 +5,7 @@ from tempfile import NamedTemporaryFile import base64 import sympy as sp -__all__ = ['log_progress', 'make_imshow_animation', 'make_surface_plot_animation', - 'display_animation', 'set_display_mode'] +__all__ = ['log_progress', 'make_imshow_animation', 'display_animation', 'set_display_mode'] def log_progress(sequence, every=None, size=None, name='Items'): @@ -98,28 +97,6 @@ def make_imshow_animation(grid, grid_update_function, frames=90, **_): return animation.FuncAnimation(fig, partial(update_figure, image=grid), frames=frames) -def make_surface_plot_animation(run_function, frames=90, interval=30): - from mpl_toolkits.mplot3d import Axes3D - import matplotlib.animation as animation - import matplotlib.pyplot as plt - from matplotlib import cm - - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - x, y, data = run_function(1) - ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) - ax.set_zlim(-1.0, 1.0) - - def update_figure(*_): - x_grid, y_grid, d = run_function(1) - ax.clear() - plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) - ax.set_zlim(-1.0, 1.0) - return plot, - - return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False) - - # ------- Version 1: Embed the animation as HTML5 video --------- ---------------------------------- def display_as_html_video(animation, fps=30, show=True, **_): diff --git a/plot2d.py b/plot2d.py index e0935c831186e378413613221f7dad25a227c9da..6157f93141343c64a00fe0e8acd8d4398ad17b20 100644 --- a/plot2d.py +++ b/plot2d.py @@ -251,3 +251,25 @@ def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_ return im, return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames) + + +def surface_plot_animation(run_function, frames=90, interval=30): + from mpl_toolkits.mplot3d import Axes3D + import matplotlib.animation as animation + import matplotlib.pyplot as plt + from matplotlib import cm + + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + x, y, data = run_function(1) + ax.plot_surface(x, y, data, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) + ax.set_zlim(-1.0, 1.0) + + def update_figure(*_): + x_grid, y_grid, d = run_function(1) + ax.clear() + plot = ax.plot_surface(x_grid, y_grid, d, rstride=2, cstride=2, color='b', cmap=cm.coolwarm,) + ax.set_zlim(-1.0, 1.0) + return plot, + + return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames, blit=False)