Skip to content

Commit ccf8502

Browse files
committed
Simplify plot_digits() and add comments, fixes ageron#479
1 parent ff86dfd commit ccf8502

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

03_classification.ipynb

+16-8
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,24 @@
272272
"def plot_digits(instances, images_per_row=10, **options):\n",
273273
" size = 28\n",
274274
" images_per_row = min(len(instances), images_per_row)\n",
275-
" images = [instance.reshape(size,size) for instance in instances]\n",
275+
" # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n",
276276
" n_rows = (len(instances) - 1) // images_per_row + 1\n",
277-
" row_images = []\n",
277+
"\n",
278+
" # Append empty images to fill the end of the grid, if needed:\n",
278279
" n_empty = n_rows * images_per_row - len(instances)\n",
279-
" images.append(np.zeros((size, size * n_empty)))\n",
280-
" for row in range(n_rows):\n",
281-
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
282-
" row_images.append(np.concatenate(rimages, axis=1))\n",
283-
" image = np.concatenate(row_images, axis=0)\n",
284-
" plt.imshow(image, cmap = mpl.cm.binary, **options)\n",
280+
" padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n",
281+
"\n",
282+
" # Reshape the array so it's organized as a grid containing 28×28 images:\n",
283+
" image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n",
284+
"\n",
285+
" # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n",
286+
" # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n",
287+
" # want to combine next to each other, using transpose(), and only then we\n",
288+
" # can reshape:\n",
289+
" big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n",
290+
" images_per_row * size)\n",
291+
" # Now that we have a big image, we just need to show it:\n",
292+
" plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n",
285293
" plt.axis(\"off\")"
286294
]
287295
},

08_dimensionality_reduction.ipynb

+17-8
Original file line numberDiff line numberDiff line change
@@ -1319,19 +1319,28 @@
13191319
"metadata": {},
13201320
"outputs": [],
13211321
"source": [
1322+
"# EXTRA\n",
13221323
"def plot_digits(instances, images_per_row=5, **options):\n",
13231324
" size = 28\n",
13241325
" images_per_row = min(len(instances), images_per_row)\n",
1325-
" images = [instance.reshape(size,size) for instance in instances]\n",
1326+
" # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n",
13261327
" n_rows = (len(instances) - 1) // images_per_row + 1\n",
1327-
" row_images = []\n",
1328+
"\n",
1329+
" # Append empty images to fill the end of the grid, if needed:\n",
13281330
" n_empty = n_rows * images_per_row - len(instances)\n",
1329-
" images.append(np.zeros((size, size * n_empty)))\n",
1330-
" for row in range(n_rows):\n",
1331-
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
1332-
" row_images.append(np.concatenate(rimages, axis=1))\n",
1333-
" image = np.concatenate(row_images, axis=0)\n",
1334-
" plt.imshow(image, cmap = mpl.cm.binary, **options)\n",
1331+
" padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n",
1332+
"\n",
1333+
" # Reshape the array so it's organized as a grid containing 28×28 images:\n",
1334+
" image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n",
1335+
"\n",
1336+
" # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n",
1337+
" # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n",
1338+
" # want to combine next to each other, using transpose(), and only then we\n",
1339+
" # can reshape:\n",
1340+
" big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n",
1341+
" images_per_row * size)\n",
1342+
" # Now that we have a big image, we just need to show it:\n",
1343+
" plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n",
13351344
" plt.axis(\"off\")"
13361345
]
13371346
},

0 commit comments

Comments
 (0)