{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pystencils.session import *\n",
    "import timeit\n",
    "%load_ext Cython"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo: Benchmark numpy, Cython, pystencils\n",
    "In this notebook we compare and benchmark different ways of implementing a simple stencil kernel in Python.\n",
    "The example kernel computes the average of the four direct neighbors of each cell of a 2D array and stores it in a second array. To prevent out-of-bounds accesses, we skip the cells at the border and compute values only in the range `[1:-1, 1:-1]`."
   ]
  },
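  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written as a formula, with $s$ and $d$ denoting the source and destination arrays and $N_x \\times N_y$ their shape (this notation is ours and does not appear in the code), each interior cell receives the value\n",
    "\n",
    "$$d_{x,y} = \\frac{1}{4}\\left(s_{x+1,y} + s_{x-1,y} + s_{x,y+1} + s_{x,y-1}\\right), \\qquad 1 \\le x \\le N_x - 2, \\quad 1 \\le y \\le N_y - 2.$$"
   ]
  },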
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementations\n",
    "\n",
    "The first implementation is a pure Python implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_pure_python(src, dst):       \n",
    "    for x in range(1, src.shape[0] - 1):\n",
    "        for y in range(1, src.shape[1] - 1):\n",
    "            dst[x, y] = (src[x + 1, y] + src[x - 1, y] +\n",
    "                         src[x, y + 1] + src[x, y - 1]) / 4"
   ]
  },
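  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example, consider a hypothetical 3x3 array (the names `demo_src` and `demo_dst` are only for this illustration). It has exactly one interior cell, `(1, 1)`, so the loops of `avg_pure_python`, repeated inline here to keep the cell self-contained, write a single value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 3x3 example array: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n",
    "demo_src = np.arange(9, dtype=np.float64).reshape(3, 3)\n",
    "demo_dst = np.zeros_like(demo_src)\n",
    "\n",
    "# same loops as avg_pure_python; only (1, 1) lies in the interior\n",
    "for x in range(1, demo_src.shape[0] - 1):\n",
    "    for y in range(1, demo_src.shape[1] - 1):\n",
    "        demo_dst[x, y] = (demo_src[x + 1, y] + demo_src[x - 1, y] +\n",
    "                          demo_src[x, y + 1] + demo_src[x, y - 1]) / 4\n",
    "\n",
    "# neighbors of (1, 1) are 7, 1, 5 and 3: (7 + 1 + 5 + 3) / 4 = 4.0\n",
    "print(demo_dst[1, 1])  # 4.0\n",
    "# border cells are not written and keep their initial value\n",
    "print(demo_dst[0, 0])  # 0.0"
   ]
  },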
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This version is expected to be slow, since both loops and every single element access are executed by the Python interpreter.\n",
    "\n",
    "Next, we let *numpy* do the looping. The first numpy version uses `np.roll` to shift the whole array by one element in each direction. Since `np.roll` returns a copy, this version allocates a new array for each accessed neighbor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_numpy_roll(src, dst):\n",
    "    neighbors = [np.roll(src, axis=a, shift=s) for a in (0, 1) for s in (-1, 1)]\n",
    "    np.divide(sum(neighbors), 4, out=dst)"
   ]
  },
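  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.roll` shifts cyclically: values pushed past one end of an axis reappear at the other end. Because `avg_numpy_roll` writes the full-size result via `np.divide(..., out=dst)`, it also overwrites the border cells of `dst`, averaging them with neighbors wrapped around from the opposite side, while the other versions leave the border untouched. A minimal 1D illustration (the array name `roll_demo` is hypothetical):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "roll_demo = np.arange(5)\n",
    "# positive shift: the last element wraps around to the front\n",
    "print(np.roll(roll_demo, 1))   # [4 0 1 2 3]\n",
    "# negative shift: the first element wraps around to the back\n",
    "print(np.roll(roll_demo, -1))  # [1 2 3 4 0]"
   ]
  },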
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
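    "Slicing returns views instead of copies, so this second numpy version avoids the neighbor arrays allocated by `np.roll` (the additions themselves still create temporary arrays):"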
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_numpy_slice(src, dst):\n",
    "    dst[1:-1, 1:-1] = (src[2:, 1:-1] + src[:-2, 1:-1] +\n",
    "                       src[1:-1, 2:] + src[1:-1, :-2]) / 4"
   ]
  },
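  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Whether an expression yields a view or a copy can be checked with `np.shares_memory`. The following sketch (with a hypothetical random array `view_demo`) shows that the shifted slices share memory with the original array, while `np.roll` and the sum of two slices do not. So `avg_numpy_slice` no longer copies the neighbor arrays, but it still allocates temporaries for the intermediate sums:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "view_demo = np.random.rand(6, 6)\n",
    "\n",
    "# basic slicing returns a view into the same buffer\n",
    "print(np.shares_memory(view_demo[2:, 1:-1], view_demo))  # True\n",
    "# np.roll returns a newly allocated array\n",
    "print(np.shares_memory(np.roll(view_demo, -1, axis=0), view_demo))  # False\n",
    "# arithmetic on views still allocates a temporary result array\n",
    "print(np.shares_memory(view_demo[2:, 1:-1] + view_demo[:-2, 1:-1], view_demo))  # False"
   ]
  },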
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
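    "To optimize the kernel further, we switch to Cython, which compiles the explicit loops to C. Bounds checking and negative-index wraparound are disabled, since all indices stay within the array."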
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%cython\n",
    "import cython\n",
    "\n",
    "@cython.boundscheck(False)\n",
    "@cython.wraparound(False)\n",
    "def avg_cython(object[double, ndim=2] src, object[double, ndim=2] dst):\n",
    "    cdef int xs, ys, x, y\n",
    "    xs, ys = src.shape\n",
    "    for x in range(1, xs - 1):\n",
    "        for y in range(1, ys - 1):\n",
    "            dst[x, y] = (src[x + 1, y] + src[x - 1, y] +\n",
    "                         src[x, y + 1] + src[x, y - 1]) / 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If it is installed, we also try the *numba* just-in-time compiler:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from numba import jit\n",
    "\n",
    "    @jit(nopython=True)\n",
    "    def avg_numba(src, dst):\n",
    "        dst[1:-1, 1:-1] = (src[2:, 1:-1] + src[:-2, 1:-1] +\n",
    "                           src[1:-1, 2:] + src[1:-1, :-2]) / 4\n",
    "except ImportError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we create a *pystencils* version of the same stencil:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "src, dst = ps.fields(\"src, dst: [2D]\")\n",
    "\n",
    "update = ps.Assignment(dst[0, 0],\n",
    "                       (src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1]) / 4)\n",
    "avg_pystencils = ps.create_kernel(update).compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_implementations = {\n",
    "    'pure Python': avg_pure_python,\n",
    "    'numpy roll': avg_numpy_roll,\n",
    "    'numpy slice': avg_numpy_slice,\n",
    "    'pystencils': avg_pystencils,\n",
    "}\n",
    "if 'avg_cython' in globals():\n",
    "    all_implementations['Cython'] = avg_cython\n",
    "if 'avg_numba' in globals():\n",
    "    all_implementations['numba'] = avg_numba"
   ]
  },
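  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The benchmark below only measures runtime, so it is worth confirming first that the implementations agree. The following self-contained sketch compares a loop version and a slicing version on the interior of a random array; the helper names `loop_reference` and `slice_reference` are hypothetical and simply repeat the two approaches shown above. The same comparison could be run over `all_implementations`, restricted to the interior, because `numpy roll` also overwrites the border."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def loop_reference(src, dst):\n",
    "    # explicit loops, as in avg_pure_python\n",
    "    for x in range(1, src.shape[0] - 1):\n",
    "        for y in range(1, src.shape[1] - 1):\n",
    "            dst[x, y] = (src[x + 1, y] + src[x - 1, y] +\n",
    "                         src[x, y + 1] + src[x, y - 1]) / 4\n",
    "\n",
    "def slice_reference(src, dst):\n",
    "    # vectorized with shifted views of the interior\n",
    "    dst[1:-1, 1:-1] = (src[2:, 1:-1] + src[:-2, 1:-1] +\n",
    "                       src[1:-1, 2:] + src[1:-1, :-2]) / 4\n",
    "\n",
    "check_in = np.random.rand(16, 16)\n",
    "expected = np.zeros_like(check_in)\n",
    "actual = np.zeros_like(check_in)\n",
    "loop_reference(check_in, expected)\n",
    "slice_reference(check_in, actual)\n",
    "\n",
    "# compare the interior only, since some versions (numpy roll) also write the border\n",
    "print(np.allclose(expected[1:-1, 1:-1], actual[1:-1, 1:-1]))  # True"
   ]
  },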
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark functions\n",
    "\n",
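    "We implement two short helper functions: `get_arrays` creates a random input array and an uninitialized output array of a given shape, and `do_benchmark` measures the average runtime of a single call with `timeit`, after one warm-up call (which, e.g., triggers numba's just-in-time compilation)."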
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_arrays(shape):\n",
    "    in_arr = np.random.rand(*shape)\n",
    "    out_arr = np.empty_like(in_arr)\n",
    "    return in_arr, out_arr\n",
    "\n",
    "def do_benchmark(func, shape):\n",
    "    in_arr, out_arr = get_arrays(shape)\n",
    "    func(src=in_arr, dst=out_arr) # warmup\n",
    "    timer = timeit.Timer('f(src=src, dst=dst)', globals={'f': func, 'src': in_arr, 'dst': out_arr})\n",
    "    calls, time_taken = timer.autorange()\n",
    "    return time_taken / calls"
   ]
  },
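  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`timeit.Timer.autorange` picks the number of loops automatically: it calls `timeit()` with 1, 2, 5, 10, 20, 50, ... loops until the total time reaches at least 0.2 seconds and returns the loop count together with the total time, which is why `do_benchmark` divides the total time by the loop count. The sketch below uses a hypothetical stand-in workload `busy_work`. As a possible variation (not what `do_benchmark` does), repeating the measurement and keeping the minimum makes the result less sensitive to other processes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timeit\n",
    "\n",
    "def busy_work():\n",
    "    # hypothetical stand-in for a stencil kernel\n",
    "    return sum(range(1000))\n",
    "\n",
    "demo_timer = timeit.Timer(busy_work)\n",
    "# autorange returns (number of loops, total seconds for those loops)\n",
    "demo_number, demo_total = demo_timer.autorange()\n",
    "print(demo_number >= 1, demo_total >= 0.2)  # True True\n",
    "\n",
    "# variation: best of 3 repetitions, converted to seconds per call\n",
    "best_per_call = min(demo_timer.repeat(repeat=3, number=demo_number)) / demo_number\n",
    "print(f\"{best_per_call:.2e} s per call\")"
   ]
  },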
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 576x576 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_order = ['pystencils', 'Cython', 'numba', 'numpy slice', 'numpy roll', 'pure Python']\n",
    "plot_order = [p for p in plot_order if p in all_implementations]\n",
    "\n",
    "def bar_plot(*shape):\n",
    "    names = plot_order\n",
    "    runtimes = tuple(do_benchmark(all_implementations[name], shape) for name in names)\n",
    "    for runtime, name in zip(runtimes, names):\n",
    "        # assert that pystencils is the fastest\n",
    "        # if a change degrades the performance of pystencils, the CI system detects it automatically\n",
    "        assert runtime >= runtimes[names.index('pystencils')], \"pystencils is slower than \" + name\n",
    "    # factor by which each version is slower than the fastest one\n",
    "    slowdowns = tuple(runtime / min(runtimes) for runtime in runtimes)\n",
    "    y_pos = np.arange(len(names))\n",
    "    labels = tuple(f\"{name} ({round(slowdown, 1)} x)\" for name, slowdown in zip(names, slowdowns))\n",
    "    \n",
    "    plt.text(0.5, 0.5, f\"Size {shape}\", horizontalalignment='center', fontsize=16,\n",
    "             verticalalignment='center', transform=plt.gca().transAxes)\n",
    "    plt.barh(y_pos, runtimes, log=True)\n",
    "     \n",
    "    plt.yticks(y_pos, labels);\n",
    "    plt.xlabel('Runtime of single iteration')\n",
    "    \n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.subplot(3, 1, 1)\n",
    "bar_plot(32, 32)\n",
    "\n",
    "plt.subplot(3, 1, 2)\n",
    "bar_plot(128, 128)\n",
    "\n",
    "plt.subplot(3, 1, 3)\n",
    "bar_plot(2048, 2048)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All runtimes are plotted on a logarithmic scale. The number next to each label is the factor by which that version is slower than the fastest one."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}