|
272 | 272 | "def plot_digits(instances, images_per_row=10, **options):\n",
|
273 | 273 | " size = 28\n",
|
274 | 274 | " images_per_row = min(len(instances), images_per_row)\n",
|
275 |
| - " images = [instance.reshape(size,size) for instance in instances]\n", |
| 275 | + " # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n", |
276 | 276 | " n_rows = (len(instances) - 1) // images_per_row + 1\n",
|
277 |
| - " row_images = []\n", |
| 277 | + "\n", |
| 278 | + " # Append empty images to fill the end of the grid, if needed:\n", |
278 | 279 | " n_empty = n_rows * images_per_row - len(instances)\n",
|
279 |
| - " images.append(np.zeros((size, size * n_empty)))\n", |
280 |
| - " for row in range(n_rows):\n", |
281 |
| - " rimages = images[row * images_per_row : (row + 1) * images_per_row]\n", |
282 |
| - " row_images.append(np.concatenate(rimages, axis=1))\n", |
283 |
| - " image = np.concatenate(row_images, axis=0)\n", |
284 |
| - " plt.imshow(image, cmap = mpl.cm.binary, **options)\n", |
| 280 | + " padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n", |
| 281 | + "\n", |
| 282 | + " # Reshape the array so it's organized as a grid containing 28×28 images:\n", |
| 283 | + " image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n", |
| 284 | + "\n", |
| 285 | + " # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n", |
| 286 | + " # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n", |
| 287 | + " # want to combine next to each other, using transpose(), and only then we\n", |
| 288 | + " # can reshape:\n", |
| 289 | + " big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n", |
| 290 | + " images_per_row * size)\n", |
| 291 | + " # Now that we have a big image, we just need to show it:\n", |
| 292 | + " plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n", |
285 | 293 | " plt.axis(\"off\")"
|
286 | 294 | ]
|
287 | 295 | },
|
|
0 commit comments