Skip to content

Commit

Permalink
Add a few fine-tuning improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
timesler committed Sep 10, 2019
1 parent 1c91225 commit ed9d89e
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions examples/finetune.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
"source": [
"# Face detection and recognition training pipeline\n",
"\n",
"The following example illustrates how to use the `facenet_pytorch` python package to perform face detection and recogition on an image dataset using an Inception Resnet V1 pretrained on the VGGFace2 dataset."
"The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. This will mostly follow standard pytorch training patterns."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -23,8 +23,6 @@
"from torch.utils.tensorboard import SummaryWriter\n",
"from torchvision import datasets, transforms\n",
"import numpy as np\n",
"import pandas as pd\n",
"import multiprocessing as mp\n",
"import os"
]
},
Expand All @@ -37,13 +35,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_dir = '/mnt/windows/Users/times/Data/vggface2/test'\n",
"batch_size = 32\n",
"epochs = 15"
"epochs = 15\n",
"workers = 4"
]
},
{
Expand All @@ -55,17 +54,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on device: cuda:0\n"
]
}
],
"outputs": [],
"source": [
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"print('Running on device: {}'.format(device))"
Expand All @@ -84,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -114,7 +105,7 @@
"source": [
"dataset = datasets.ImageFolder(data_dir)\n",
"dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n",
"loader = DataLoader(dataset, collate_fn=lambda x: x[0], num_workers=mp.cpu_count(), shuffle=False)\n",
"loader = DataLoader(dataset, collate_fn=lambda x: x[0], num_workers=workers)\n",
"\n",
"for i, (x, y) in enumerate(loader):\n",
" print(f'\\rImages processed: {i + 1:8d} of {len(loader):8d}', end='')\n",
Expand All @@ -139,7 +130,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -159,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -179,13 +170,13 @@
"\n",
"train_loader = DataLoader(\n",
" dataset,\n",
" num_workers=mp.cpu_count(),\n",
" num_workers=workers,\n",
" batch_size=batch_size,\n",
" sampler=SubsetRandomSampler(train_inds)\n",
")\n",
"val_loader = DataLoader(\n",
" dataset,\n",
" num_workers=mp.cpu_count(),\n",
" num_workers=workers,\n",
" batch_size=batch_size,\n",
" sampler=SubsetRandomSampler(val_inds)\n",
")"
Expand All @@ -200,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit ed9d89e

Please sign in to comment.