diff --git a/__init__.py b/__init__.py
index 3e7452a2..201b5e56 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,4 @@
 from .models.inception_resnet_v1 import InceptionResnetV1
 from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten
 from .models.utils.detect_face import extract_face
+from .models.utils import training
diff --git a/examples/train.ipynb b/examples/train.ipynb
new file mode 100644
index 00000000..dcf5557f
--- /dev/null
+++ b/examples/train.ipynb
@@ -0,0 +1,320 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Face detection and recognition training pipeline\n",
+    "\n",
+    "The following example illustrates how to use the `facenet_pytorch` Python package to perform face detection and recognition on an image dataset using an Inception Resnet V1 pretrained on the VGGFace2 dataset.\n",
+    "\n",
+    "The following Pytorch methods are included:\n",
+    "* Datasets\n",
+    "* Dataloaders\n",
+    "* GPU/CPU processing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from facenet_pytorch import MTCNN, InceptionResnetV1, prewhiten, training\n",
+    "import torch\n",
+    "from torch.utils.data import DataLoader, SubsetRandomSampler\n",
+    "from torch import optim\n",
+    "from torch.optim.lr_scheduler import MultiStepLR\n",
+    "from torchvision import datasets, transforms\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import multiprocessing as mp\n",
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Define run parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_dir = '../../../data/vggface2/train'\n",
+    "batch_size = 16\n",
+    "epochs = 15"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Determine if an NVIDIA GPU is available"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running on device: cpu\n"
+     ]
+    }
+   ],
+   "source": [
+    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
+    "print('Running on device: {}'.format(device))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Define MTCNN module\n",
+    "\n",
+    "Default parameters are shown for illustration, but are not needed. Note that, since MTCNN is a collection of neural nets and other code, the device must be passed in the following way to enable copying of objects when needed internally.\n",
+    "\n",
+    "See `help(MTCNN)` for more details."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mtcnn = MTCNN(\n",
+    "    image_size=160, margin=0, min_face_size=20,\n",
+    "    thresholds=[0.6, 0.7, 0.7], factor=0.709, prewhiten=True,\n",
+    "    device=device\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Perform MTCNN facial detection\n",
+    "\n",
+    "Iterate through the DataLoader object and obtain cropped faces. A quick single-image sanity check is sketched first."
+   ]
+  },
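+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an optional first step, MTCNN can be sanity-checked on a single image (a minimal sketch, assuming `data_dir` contains at least one image): the module returns a prewhitened face tensor of shape `3 x 160 x 160`, or `None` if no face is detected."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sanity check only; not required for the pipeline below.\n",
+    "img, _ = datasets.ImageFolder(data_dir)[0]\n",
+    "face = mtcnn(img)\n",
+    "if face is not None:\n",
+    "    print(face.shape)  # torch.Size([3, 160, 160])"
+   ]
+  },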
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Images processed: 6353 of 6353"
+     ]
+    }
+   ],
+   "source": [
+    "dataset = datasets.ImageFolder(data_dir)\n",
+    "dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n",
+    "loader = DataLoader(dataset, collate_fn=lambda x: x[0], num_workers=mp.cpu_count(), shuffle=False)\n",
+    "\n",
+    "for i, (x, y) in enumerate(loader):\n",
+    "    print(f'\\rImages processed: {i + 1} of {len(loader)}', end='')\n",
+    "    save_dir = os.path.join(data_dir + '_cropped', dataset.idx_to_class[y])\n",
+    "    os.makedirs(save_dir, exist_ok=True)\n",
+    "    filename = f'{len(os.listdir(save_dir)):05n}.png'\n",
+    "    mtcnn(x, save_path=os.path.join(save_dir, filename))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Define Inception Resnet V1 module\n",
+    "\n",
+    "Set classify=True to output class logits rather than embeddings.\n",
+    "\n",
+    "See `help(InceptionResnetV1)` for more details."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "resnet = InceptionResnetV1(\n",
+    "    pretrained='vggface2',\n",
+    "    classify=True,\n",
+    "    num_classes=len(dataset.class_to_idx)\n",
+    ").to(device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Define optimizer, scheduler, dataset, and dataloader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n",
+    "scheduler = MultiStepLR(optimizer, [5, 10])\n",
+    "\n",
+    "trans = transforms.Compose([\n",
+    "    np.float32,\n",
+    "    transforms.ToTensor(),\n",
+    "    prewhiten\n",
+    "])\n",
+    "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
+    "img_inds = np.arange(len(dataset))\n",
+    "np.random.shuffle(img_inds)\n",
+    "train_inds = img_inds[:int(0.8 * len(img_inds))]\n",
+    "val_inds = img_inds[int(0.8 * len(img_inds)):]\n",
+    "\n",
+    "train_loader = DataLoader(\n",
+    "    dataset,\n",
+    "    num_workers=mp.cpu_count(),\n",
+    "    batch_size=batch_size,\n",
+    "    sampler=SubsetRandomSampler(train_inds)\n",
+    ")\n",
+    "val_loader = DataLoader(\n",
+    "    dataset,\n",
+    "    num_workers=mp.cpu_count(),\n",
+    "    batch_size=batch_size,\n",
+    "    sampler=SubsetRandomSampler(val_inds)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Define loss and evaluation functions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_fn = torch.nn.CrossEntropyLoss()\n",
+    "metrics = {\n",
+    "    'fps': training.BatchTimer(),\n",
+    "    'acc': training.accuracy\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Train model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "Initial\n",
+      "----------\n",
+      "Eval | 80/80 | loss: 2.9421 | fps: 7.6358 | acc: 0.0602 \n",
+      "\n",
+      "\n",
+      "Epoch 1/15\n",
+      "----------\n",
+      "Train | 317/317 | loss: 1.9690 | fps: 2.4324 | acc: 0.5260 \n",
+      "Eval | 80/80 | loss: 1.4802 | fps: 8.2792 | acc: 0.5591 \n",
+      "\n",
+      "\n",
+      "Epoch 2/15\n",
+      "----------\n",
+      "Train | 317/317 | loss: 1.0367 | fps: 2.4487 | acc: 0.7467 \n",
+      "Eval | 80/80 | loss: 0.8572 | fps: 8.0474 | acc: 0.7799 \n",
+      "\n",
+      "\n",
+      "Epoch 3/15\n",
+ "----------\n", + "Train | 124/317 | loss: 0.6837 | fps: 2.4360 | acc: 0.8362 " + ] + } + ], + "source": [ + "print(f'\\n\\nInitial')\n", + "print('-' * 10)\n", + "resnet.eval()\n", + "training.pass_epoch(\n", + " resnet, loss_fn, val_loader,\n", + " batch_metrics=metrics, show_running=True, device=device\n", + ")\n", + "\n", + "for epoch in range(epochs):\n", + " print(f'\\n\\nEpoch {epoch + 1}/{epochs}')\n", + " print('-' * 10)\n", + "\n", + " resnet.train()\n", + " training.pass_epoch(\n", + " resnet, loss_fn, train_loader, optimizer, scheduler,\n", + " batch_metrics=metrics, show_running=True, device=device\n", + " )\n", + "\n", + " resnet.eval()\n", + " training.pass_epoch(\n", + " resnet, loss_fn, val_loader,\n", + " batch_metrics=metrics, show_running=True, device=device\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/inception_resnet_v1.py b/models/inception_resnet_v1.py index 2294cd79..ea0ca062 100644 --- a/models/inception_resnet_v1.py +++ b/models/inception_resnet_v1.py @@ -191,8 +191,9 @@ class InceptionResnetV1(nn.Module): (default: {None}) classify {bool} -- Whether the model should output classification probabilities or feature embeddings. (default: {False}) - num_classes {int} -- Number of output classes. Ignored if 'pretrained' is set, in which - case the number of classes is set to that used for training. (default: {1001}) + num_classes {int} -- Number of output classes. Ignored if 'pretrained' is set, and + num_classes not equal to that used for the pretrained model, the final linear layer + will be randomly initialized. (default: {1001}) """ def __init__(self, pretrained=None, classify=False, num_classes=1001): super().__init__() @@ -202,10 +203,11 @@ def __init__(self, pretrained=None, classify=False, num_classes=1001): self.classify = classify self.num_classes = num_classes + tmp_classes = self.num_classes if pretrained == 'vggface2': - self.num_classes = 8631 + tmp_classes = 8631 elif pretrained == 'casia-webface': - self.num_classes = 10575 + tmp_classes = 10575 # Define layers self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) @@ -248,11 +250,14 @@ def __init__(self, pretrained=None, classify=False, num_classes=1001): self.last_linear = nn.Linear(1792, 512, bias=False) self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) - self.logits = nn.Linear(512, self.num_classes) + self.logits = nn.Linear(512, tmp_classes) self.softmax = nn.Softmax(dim=1) if pretrained is not None: load_weights(self, pretrained) + + if self.num_classes != tmp_classes: + self.logits = nn.Linear(512, self.num_classes) def forward(self, x): """Calculate embeddings or probabilities given a batch of input image tensors. 
@@ -282,7 +287,6 @@ def forward(self, x):
         x = F.normalize(x, p=2, dim=1)
         if self.classify:
             x = self.logits(x)
-            x = self.softmax(x)
         return x
 
diff --git a/models/utils/training.py b/models/utils/training.py
new file mode 100644
index 00000000..bab8f306
--- /dev/null
+++ b/models/utils/training.py
@@ -0,0 +1,123 @@
+import torch
+import numpy as np
+import time
+
+
+class Logger(object):
+
+    def __init__(self, mode, length, calculate_mean=False):
+        self.mode = mode
+        self.length = length
+        self.calculate_mean = calculate_mean
+        if self.calculate_mean:
+            self.fn = lambda x, i: x / (i + 1)
+        else:
+            self.fn = lambda x, i: x
+
+    def __call__(self, loss, metrics, i):
+        track_str = f'\r{self.mode} | {i + 1:5d}/{self.length:<5d}| '
+        loss_str = f'loss: {self.fn(loss, i):9.4f} | '
+        metric_str = ' | '.join(f'{k}: {self.fn(v, i):9.4f}' for k, v in metrics.items())
+        print(track_str + loss_str + metric_str + ' ', end='')
+        if i + 1 == self.length:
+            print('')
+
+
+class BatchTimer(object):
+    """Batch timing class.
+    Use this class for tracking training and testing time/rate per batch or per sample.
+
+    Keyword Arguments:
+        rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
+            per batch or sample). (default: {True})
+        per_sample {bool} -- Whether to report times or rates per sample or per batch.
+            (default: {True})
+    """
+
+    def __init__(self, rate=True, per_sample=True):
+        self.start = time.time()
+        self.end = None
+        self.rate = rate
+        self.per_sample = per_sample
+
+    def __call__(self, y_pred, y):
+        self.end = time.time()
+        elapsed = self.end - self.start
+        self.start = self.end
+        self.end = None
+
+        if self.per_sample:
+            elapsed /= len(y_pred)
+        if self.rate:
+            elapsed = 1 / elapsed
+
+        return torch.tensor(elapsed)
+
+
+def accuracy(logits, y):
+    _, preds = torch.max(logits, 1)
+    return (preds == y).float().mean()
+
+
+def pass_epoch(
+    model, loss_fn, loader, optimizer=None, scheduler=None,
+    batch_metrics={'time': BatchTimer()}, show_running=True,
+    device='cpu'
+):
+    """Train or evaluate over a data epoch.
+
+    Arguments:
+        model {torch.nn.Module} -- Pytorch model.
+        loss_fn {callable} -- A function to compute (scalar) loss.
+        loader {torch.utils.data.DataLoader} -- A pytorch data loader.
+
+    Keyword Arguments:
+        optimizer {torch.optim.Optimizer} -- A pytorch optimizer. (default: {None})
+        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None})
+        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
+            is a simple timer. A progressive average of these metrics, along with the average
+            loss, is printed every batch. (default: {{'time': BatchTimer()}})
+        show_running {bool} -- Whether to print rolling averages of losses and metrics (True) or
+            values for the current batch only (False). (default: {True})
+        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
+
+    Returns:
+        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
+            metric values across the epoch.
+ """ + + mode = 'Train' if model.training else 'Eval ' + logger = Logger(mode, length=len(loader), calculate_mean=show_running) + loss = 0 + metrics = {} + + for i_batch, (x, y) in enumerate(loader): + x = x.to(device) + y = y.to(device) + y_pred = model(x) + loss_batch = loss_fn(y_pred, y) + + if model.training: + loss_batch.backward() + optimizer.step() + optimizer.zero_grad() + + metrics_batch = {} + for metric_name, metric_fn in batch_metrics.items(): + metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu() + metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name] + + loss_batch = loss_batch.detach().cpu() + loss += loss_batch + if show_running: + logger(loss, metrics, i_batch) + else: + logger(loss_batch, metrics_batch, i_batch) + + if model.training and scheduler is not None: + scheduler.step() + + loss = loss / (i_batch + 1) + metrics = {k: v / (i_batch + 1) for k, v in metrics.items()} + + return loss, metrics diff --git a/setup.py b/setup.py index cd1af6ae..93a481ca 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='facenet-pytorch', - version='0.1.0', + version='0.2.1', author='Tim Esler', author_email='tim.esler@gmail.com', description='Pretrained Pytorch face detection and recognition models', diff --git a/tests/travis_test.py b/tests/travis_test.py index 1822bf94..f0e2f432 100644 --- a/tests/travis_test.py +++ b/tests/travis_test.py @@ -118,8 +118,6 @@ def get_image(path, trans): resnet_pt = InceptionResnetV1(pretrained=ds, classify=True).eval() prob = resnet_pt(aligned) -if sys.platform != 'win32': - assert prob.mean().detach().item() - 9.4563e-05 < 1e-5 # MULTI-FACE TEST