forked from timesler/facenet-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
171 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .models.inception_resnet_v1 import InceptionResnetV1 | ||
from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten | ||
from .models.utils.detect_face import extract_face | ||
from .models.utils import training |
Binary file added
BIN
+949 Bytes
...09_18-46-44_tim-xps-ubuntu/acc/Train/events.out.tfevents.1568080115.tim-xps-ubuntu.7662.6
Binary file not shown.
Binary file added
BIN
+79 Bytes
...09_18-46-44_tim-xps-ubuntu/acc/Valid/events.out.tfevents.1568080114.tim-xps-ubuntu.7662.3
Binary file not shown.
Binary file added
BIN
+40 Bytes
...s/runs/Sep09_18-46-44_tim-xps-ubuntu/events.out.tfevents.1568080004.tim-xps-ubuntu.7662.0
Binary file not shown.
Binary file added
BIN
+949 Bytes
...09_18-46-44_tim-xps-ubuntu/fps/Train/events.out.tfevents.1568080115.tim-xps-ubuntu.7662.5
Binary file not shown.
Binary file added
BIN
+79 Bytes
...09_18-46-44_tim-xps-ubuntu/fps/Valid/events.out.tfevents.1568080114.tim-xps-ubuntu.7662.2
Binary file not shown.
Binary file added
BIN
+971 Bytes
...9_18-46-44_tim-xps-ubuntu/loss/Train/events.out.tfevents.1568080115.tim-xps-ubuntu.7662.4
Binary file not shown.
Binary file added
BIN
+80 Bytes
...9_18-46-44_tim-xps-ubuntu/loss/Valid/events.out.tfevents.1568080114.tim-xps-ubuntu.7662.1
Binary file not shown.
Binary file added
BIN
+40 Bytes
.../runs/Sep09_18-59-28_tim-xps-ubuntu/events.out.tfevents.1568080768.tim-xps-ubuntu.31566.0
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import torch | ||
import numpy as np | ||
import time | ||
|
||
|
||
class Logger(object): | ||
|
||
def __init__(self, mode, length, calculate_mean=False): | ||
self.mode = mode | ||
self.length = length | ||
self.calculate_mean = calculate_mean | ||
if self.calculate_mean: | ||
self.fn = lambda x, i: x / (i + 1) | ||
else: | ||
self.fn = lambda x, i: x | ||
|
||
def __call__(self, loss, metrics, i): | ||
track_str = f'\r{self.mode} | {i + 1:5d}/{self.length:<5d}| ' | ||
loss_str = f'loss: {self.fn(loss, i):9.4f} | ' | ||
metric_str = ' | '.join(f'{k}: {self.fn(v, i):9.4f}' for k, v in metrics.items()) | ||
print(track_str + loss_str + metric_str + ' ', end='') | ||
if i + 1 == self.length: | ||
print('') | ||
|
||
|
||
class BatchTimer(object): | ||
"""Batch timing class. | ||
Use this class for tracking training and testing time/rate per batch or per sample. | ||
Keyword Arguments: | ||
rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds | ||
per batch or sample). (default: {True}) | ||
per_sample {bool} -- Whether to report times or rates per sample or per batch. | ||
(default: {True}) | ||
""" | ||
|
||
def __init__(self, rate=True, per_sample=True): | ||
self.start = time.time() | ||
self.end = None | ||
self.rate = rate | ||
self.per_sample = per_sample | ||
|
||
def __call__(self, y_pred, y): | ||
self.end = time.time() | ||
elapsed = self.end - self.start | ||
self.start = self.end | ||
self.end = None | ||
|
||
if self.per_sample: | ||
elapsed /= len(y_pred) | ||
if self.rate: | ||
elapsed = 1 / elapsed | ||
|
||
return torch.tensor(elapsed) | ||
|
||
|
||
def accuracy(logits, y): | ||
_, preds = torch.max(logits, 1) | ||
return (preds == y).float().mean() | ||
|
||
|
||
def pass_epoch( | ||
model, loss_fn, loader, optimizer=None, scheduler=None, | ||
batch_metrics={'time': BatchTimer()}, show_running=True, | ||
device='cpu', writer=None | ||
): | ||
"""Train or evaluate over a data epoch. | ||
Arguments: | ||
model {torch.nn.Module} -- Pytorch model. | ||
loss_fn {callable} -- A function to compute (scalar) loss. | ||
loader {torch.utils.data.DataLoader} -- A pytorch data loader. | ||
Keyword Arguments: | ||
optimizer {torch.optim.Optimizer} -- A pytorch optimizer. | ||
scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None}) | ||
batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default | ||
is a simple timer. A progressive average of these metrics, along with the average | ||
loss, is printed every batch. (default: {{'time': iter_timer()}}) | ||
show_running {bool} -- Whether or not to print losses and metrics for the current batch | ||
or rolling averages. (default: {False}) | ||
device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'}) | ||
writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None}) | ||
Returns: | ||
tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average | ||
metric values across the epoch. | ||
""" | ||
|
||
mode = 'Train' if model.training else 'Valid' | ||
logger = Logger(mode, length=len(loader), calculate_mean=show_running) | ||
loss = 0 | ||
metrics = {} | ||
|
||
for i_batch, (x, y) in enumerate(loader): | ||
x = x.to(device) | ||
y = y.to(device) | ||
y_pred = model(x) | ||
loss_batch = loss_fn(y_pred, y) | ||
|
||
if model.training: | ||
loss_batch.backward() | ||
optimizer.step() | ||
optimizer.zero_grad() | ||
|
||
metrics_batch = {} | ||
for metric_name, metric_fn in batch_metrics.items(): | ||
metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu() | ||
metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name] | ||
|
||
if writer is not None and model.training: | ||
if writer.iteration % writer.interval == 0: | ||
writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration) | ||
for metric_name, metric_batch in metrics_batch.items(): | ||
writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration) | ||
writer.iteration += 1 | ||
|
||
loss_batch = loss_batch.detach().cpu() | ||
loss += loss_batch | ||
if show_running: | ||
logger(loss, metrics, i_batch) | ||
else: | ||
logger(loss_batch, metrics_batch, i_batch) | ||
|
||
if model.training and scheduler is not None: | ||
scheduler.step() | ||
|
||
loss = loss / (i_batch + 1) | ||
metrics = {k: v / (i_batch + 1) for k, v in metrics.items()} | ||
|
||
if writer is not None and not model.training: | ||
writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration) | ||
for metric_name, metric in metrics.items(): | ||
writer.add_scalars(metric_name, {mode: metric}) | ||
|
||
return loss, metrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,34 @@ | ||
import setuptools | ||
import setuptools, os | ||
|
||
with open('facenet_pytorch/README.md', 'r') as f: | ||
PACKAGE_NAME = 'facenet-pytorch' | ||
VERSION = '0.2.2' | ||
AUTHOR = 'Tim Esler' | ||
EMAIL = '[email protected]' | ||
DESCRIPTION = 'Pretrained Pytorch face detection and recognition models' | ||
GITHUB_URL = 'https://github.com/timesler/facenet-pytorch' | ||
|
||
parent_dir = os.path.dirname(os.path.realpath(__file__)) | ||
import_name = os.path.basename(parent_dir) | ||
|
||
with open(f'{parent_dir}/README.md', 'r') as f: | ||
long_description = f.read() | ||
|
||
setuptools.setup( | ||
name='facenet-pytorch', | ||
version='0.1.0', | ||
author='Tim Esler', | ||
author_email='[email protected]', | ||
description='Pretrained Pytorch face detection and recognition models', | ||
name=PACKAGE_NAME, | ||
version=VERSION, | ||
author=AUTHOR, | ||
author_email=EMAIL, | ||
description=DESCRIPTION, | ||
long_description=long_description, | ||
long_description_content_type='text/markdown', | ||
url='https://github.com/timesler/facenet-pytorch', | ||
url=GITHUB_URL, | ||
packages=[ | ||
'facenet_pytorch', | ||
'facenet_pytorch.models', | ||
'facenet_pytorch.models.utils', | ||
'facenet_pytorch.data', | ||
], | ||
], | ||
package_dir={'facenet_pytorch':'.'}, | ||
package_data={'': ['*net.pt']}, | ||
classifiers=[ | ||
"Programming Language :: Python :: 3", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters