Initial commit bf20f72: 65 changed files with 7,467 additions and 0 deletions.
# Yet another Deep Danbooru project
But based on [RegNetY-8G](https://arxiv.org/abs/2003.13678): relatively lightweight, and designed to run fast on a GPU. \
Training was done with mixed-precision training on a single RTX 2080 Ti for 3 weeks. \
Some code is adapted from https://github.com/facebookresearch/pycls
# What do I need?
You need to download [save_4000000.ckpt]() from the release page and place it in the same folder as `test.py`.
# How to use?
`python test.py --model save_4000000.ckpt --image <PATH_TO_IMAGE>`
# What to do in the future?
1. Quantize to 8 bit
The RegDanbooru2019 model definition (the file name is not shown in this rendering):
import torch
import torch.nn as nn
import torch.nn.functional as F

from RegNetY_8G import build_model


class RegDanbooru2019(nn.Module):
    def __init__(self):
        super(RegDanbooru2019, self).__init__()
        self.backbone = build_model()
        num_p = sum(p.numel() for p in self.backbone.parameters() if p.requires_grad)
        print('Backbone has %d parameters' % num_p)
        # 2016 is the final width of the RegNetY-8.0GF backbone; 4096 tag logits
        self.head_danbooru = nn.Linear(2016, 4096)

    def forward_train_head(self, images):
        """
        images of shape [N, 3, 512, 512]
        """
        # Keep the backbone frozen; only the head receives gradients
        with torch.no_grad():
            feats = self.backbone(images)
            feats = F.adaptive_avg_pool2d(feats, 1).view(-1, 2016)
        danbooru_logits = self.head_danbooru(feats)  # [N, 4096]
        return danbooru_logits

    def forward(self, images):
        """
        images of shape [N, 3, 512, 512]
        """
        feats = self.backbone(images)
        feats = F.adaptive_avg_pool2d(feats, 1).view(-1, 2016)
        danbooru_logits = self.head_danbooru(feats)  # [N, 4096]
        return danbooru_logits
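For context, here is a minimal inference sketch built on this module. It is a sketch only: `test.py` is not shown in this diff, so the checkpoint layout (a `'model'` key) and the preprocessing are assumptions, not the commit's actual loading code.

```python
import torch

# Sketch under assumptions: the checkpoint key 'model' is a guess, and a
# random tensor stands in for a properly preprocessed image.
net = RegDanbooru2019().eval()
state = torch.load('save_4000000.ckpt', map_location='cpu')
net.load_state_dict(state['model'])

with torch.no_grad():
    image = torch.rand(1, 3, 512, 512)     # [N, 3, 512, 512], as the docstrings expect
    probs = torch.sigmoid(net(image))[0]   # 4096 per-tag probabilities (multi-label)
    top10 = probs.topk(10)                 # highest-confidence tag indices
    print(top10.indices.tolist(), top10.values.tolist())
```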
RegNetY-8.0GF_dds_8gpu.yaml (the config loaded by build_model below):
MODEL:
  TYPE: regnet
  NUM_CLASSES: 1000
REGNET:
  SE_ON: true
  DEPTH: 17
  W0: 192
  WA: 76.82
  WM: 2.19
  GROUP_W: 56
OPTIM:
  LR_POLICY: cos
  BASE_LR: 0.4
  MAX_EPOCH: 100
  MOMENTUM: 0.9
  WEIGHT_DECAY: 5e-5
  WARMUP_EPOCHS: 5
TRAIN:
  DATASET: imagenet
  IM_SIZE: 512
  BATCH_SIZE: 512
TEST:
  DATASET: imagenet
  IM_SIZE: 512
  BATCH_SIZE: 400
NUM_GPUS: 1
OUT_DIR: .
RegNetY_8G.py (defines the build_model imported above; adapted from pycls's test script):
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Test a trained classification model."""

import argparse
import sys

import numpy as np
import pycls.core.losses as losses
import pycls.core.model_builder as model_builder
import pycls.datasets.loader as loader
import pycls.utils.benchmark as bu
import pycls.utils.checkpoint as cu
import pycls.utils.distributed as du
import pycls.utils.logging as lu
import pycls.utils.metrics as mu
import pycls.utils.multiprocessing as mpu
import pycls.utils.net as nu
import torch
from pycls.core.config import assert_and_infer_cfg, cfg
from pycls.utils.meters import TestMeter


logger = lu.get_logger(__name__)


def log_model_info(model):
    """Logs model info"""
    logger.info("Model:\n{}".format(model))
    logger.info("Params: {:,}".format(mu.params_count(model)))
    logger.info("Flops: {:,}".format(mu.flops_count(model)))
    logger.info("Acts: {:,}".format(mu.acts_count(model)))


def build_model():
    # Load config options
    cfg.merge_from_file('RegNetY-8.0GF_dds_8gpu.yaml')
    cfg.merge_from_list([])
    assert_and_infer_cfg()
    cfg.freeze()
    # Setup logging
    lu.setup_logging()
    # Show the config
    logger.info("Config:\n{}".format(cfg))

    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Load model weights (the ImageNet checkpoint load is disabled here;
    # weights presumably come from the Danbooru checkpoint instead)
    # cu.load_checkpoint('RegNetY-8.0GF_dds_8gpu.pyth', model)
    logger.info("Loaded model weights from: {}".format('RegNetY-8.0GF_dds_8gpu.pyth'))

    # Drop the ImageNet classification head; the backbone output is used as features
    del model.head

    return model
(Several files in this commit are empty, binary, or too large to render and are not shown.)
pycls/core/losses.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Loss functions."""

import torch.nn as nn
from pycls.core.config import cfg


# Supported loss functions
_loss_funs = {"cross_entropy": nn.CrossEntropyLoss}


def get_loss_fun():
    """Retrieves the loss function."""
    assert (
        cfg.MODEL.LOSS_FUN in _loss_funs.keys()
    ), "Loss function '{}' not supported".format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()


def register_loss_fun(name, ctor):
    """Registers a loss function dynamically."""
    _loss_funs[name] = ctor
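The head of this project is multi-label (4,096 tags), so a per-tag sigmoid loss is the natural fit rather than cross-entropy. As a hedged sketch, `register_loss_fun` could attach one like this; the `"bce_logits"` name is illustrative and not defined anywhere in this commit:

```python
import torch.nn as nn

# Hypothetical registration; the name "bce_logits" is not part of this commit.
register_loss_fun("bce_logits", nn.BCEWithLogitsLoss)

# With cfg.MODEL.LOSS_FUN set to "bce_logits", get_loss_fun() would then
# return an nn.BCEWithLogitsLoss instance moved to the GPU.
```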
pycls/core/model_builder.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Model construction functions."""

import pycls.utils.logging as lu
import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet


logger = lu.get_logger(__name__)

# Supported models
_models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet}


def build_model():
    """Builds the model."""
    assert cfg.MODEL.TYPE in _models.keys(), "Model type '{}' not supported".format(
        cfg.MODEL.TYPE
    )
    assert (
        cfg.NUM_GPUS <= torch.cuda.device_count()
    ), "Cannot use more GPU devices than available"
    # Construct the model
    model = _models[cfg.MODEL.TYPE]()
    # Determine the GPU used by the current process
    cur_device = torch.cuda.current_device()
    # Transfer the model to the current GPU device
    model = model.cuda(device=cur_device)
    # Use multi-process data parallel model in the multi-gpu setting
    if cfg.NUM_GPUS > 1:
        # Make model replica operate on the current device
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
    return model


def register_model(name, ctor):
    """Registers a model dynamically."""
    _models[name] = ctor
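`register_model` lets a custom architecture reuse the same catalog and config machinery. A minimal, hypothetical sketch using the tagger class defined earlier in this commit (the `"regdanbooru"` name is illustrative):

```python
# Hypothetical: after this call, cfg.MODEL.TYPE = "regdanbooru" would make
# build_model() construct the tagger instead of a stock pycls model.
register_model("regdanbooru", RegDanbooru2019)
```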
pycls/core/optimizer.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Optimizer."""

import pycls.utils.lr_policy as lr_policy
import torch
from pycls.core.config import cfg


def construct_optimizer(model):
    """Constructs the optimizer.

    Note that the momentum update in PyTorch differs from the one in Caffe2.
    In particular,

        Caffe2:
            V := mu * V + lr * g
            p := p - V
        PyTorch:
            V := mu * V + g
            p := p - lr * V

    where V is the velocity, mu is the momentum factor, lr is the learning rate,
    g is the gradient and p are the parameters.

    Since V is defined independently of the learning rate in PyTorch,
    when the learning rate is changed there is no need to perform the
    momentum correction by scaling V (unlike in the Caffe2 case).
    """
    # Batchnorm parameters.
    bn_params = []
    # Non-batchnorm parameters.
    non_bn_parameters = []
    for name, p in model.named_parameters():
        if "bn" in name:
            bn_params.append(p)
        else:
            non_bn_parameters.append(p)
    # Apply different weight decay to Batchnorm and non-batchnorm parameters.
    bn_weight_decay = (
        cfg.BN.CUSTOM_WEIGHT_DECAY
        if cfg.BN.USE_CUSTOM_WEIGHT_DECAY
        else cfg.OPTIM.WEIGHT_DECAY
    )
    optim_params = [
        {"params": bn_params, "weight_decay": bn_weight_decay},
        {"params": non_bn_parameters, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
    ]
    # Check all parameters will be passed into optimizer.
    assert len(list(model.parameters())) == len(non_bn_parameters) + len(
        bn_params
    ), "parameter size does not match: {} + {} != {}".format(
        len(non_bn_parameters), len(bn_params), len(list(model.parameters()))
    )
    return torch.optim.SGD(
        optim_params,
        lr=cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
        dampening=cfg.OPTIM.DAMPENING,
        nesterov=cfg.OPTIM.NESTEROV,
    )


def get_epoch_lr(cur_epoch):
    """Retrieves the lr for the given epoch (as specified by the lr policy)."""
    return lr_policy.get_epoch_lr(cur_epoch)


def set_lr(optimizer, new_lr):
    """Sets the optimizer lr to the specified value."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr
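Together with `get_epoch_lr` and `set_lr`, the intended per-epoch wiring looks roughly like this (a sketch; it assumes `model` is already built and elides the batch loop):

```python
from pycls.core.config import cfg

optimizer = construct_optimizer(model)
for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
    # Apply the cosine schedule (with warmup) at the start of each epoch
    set_lr(optimizer, get_epoch_lr(cur_epoch))
    # ... per batch: forward, loss, loss.backward(), optimizer.step() ...
```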
pycls/datasets/cifar10.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""CIFAR10 dataset."""

import os
import pickle

import numpy as np
import pycls.datasets.transforms as transforms
import pycls.utils.logging as lu
import torch
import torch.utils.data
from pycls.core.config import cfg


logger = lu.get_logger(__name__)

# Per-channel mean and SD values in BGR order
_MEAN = [125.3, 123.0, 113.9]
_SD = [63.0, 62.1, 66.7]


class Cifar10(torch.utils.data.Dataset):
    """CIFAR-10 dataset."""

    def __init__(self, data_path, split):
        assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
        assert split in ["train", "test"], "Split '{}' not supported for cifar".format(
            split
        )
        logger.info("Constructing CIFAR-10 {}...".format(split))
        self._data_path = data_path
        self._split = split
        # Data format:
        #   self._inputs - (split_size, 3, im_size, im_size) ndarray
        #   self._labels - split_size list
        self._inputs, self._labels = self._load_data()

    def _load_batch(self, batch_path):
        with open(batch_path, "rb") as f:
            d = pickle.load(f, encoding="bytes")
        return d[b"data"], d[b"labels"]

    def _load_data(self):
        """Loads data in memory."""
        logger.info("{} data path: {}".format(self._split, self._data_path))
        # Compute data batch names
        if self._split == "train":
            batch_names = ["data_batch_{}".format(i) for i in range(1, 6)]
        else:
            batch_names = ["test_batch"]
        # Load data batches
        inputs, labels = [], []
        for batch_name in batch_names:
            batch_path = os.path.join(self._data_path, batch_name)
            inputs_batch, labels_batch = self._load_batch(batch_path)
            inputs.append(inputs_batch)
            labels += labels_batch
        # Combine and reshape the inputs
        inputs = np.vstack(inputs).astype(np.float32)
        inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
        return inputs, labels

    def _prepare_im(self, im):
        """Prepares the image for network input."""
        im = transforms.color_norm(im, _MEAN, _SD)
        if self._split == "train":
            im = transforms.horizontal_flip(im=im, p=0.5)
            im = transforms.random_crop(im=im, size=cfg.TRAIN.IM_SIZE, pad_size=4)
        return im

    def __getitem__(self, index):
        im, label = self._inputs[index, ...].copy(), self._labels[index]
        im = self._prepare_im(im)
        return im, label

    def __len__(self):
        return self._inputs.shape[0]
pycls/datasets/imagenet.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""ImageNet dataset."""

import os
import re

import cv2
import numpy as np
import pycls.datasets.transforms as transforms
import pycls.utils.logging as lu
import torch
import torch.utils.data
from pycls.core.config import cfg


logger = lu.get_logger(__name__)

# Per-channel mean and SD values in BGR order
_MEAN = [0.406, 0.456, 0.485]
_SD = [0.225, 0.224, 0.229]

# Eig vals and vecs of the cov mat
_EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]])
_EIG_VECS = np.array(
    [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
)


class ImageNet(torch.utils.data.Dataset):
    """ImageNet dataset."""

    def __init__(self, data_path, split):
        assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
        assert split in [
            "train",
            "val",
        ], "Split '{}' not supported for ImageNet".format(split)
        logger.info("Constructing ImageNet {}...".format(split))
        self._data_path = data_path
        self._split = split
        self._construct_imdb()

    def _construct_imdb(self):
        """Constructs the imdb."""
        # Compile the split data path
        split_path = os.path.join(self._data_path, self._split)
        logger.info("{} data path: {}".format(self._split, split_path))
        # Images are stored per class in subdirs (format: n<number>)
        self._class_ids = sorted(
            f for f in os.listdir(split_path) if re.match(r"^n[0-9]+$", f)
        )
        # Map ImageNet class ids to contiguous ids
        self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
        # Construct the image db
        self._imdb = []
        for class_id in self._class_ids:
            cont_id = self._class_id_cont_id[class_id]
            im_dir = os.path.join(split_path, class_id)
            for im_name in os.listdir(im_dir):
                self._imdb.append(
                    {"im_path": os.path.join(im_dir, im_name), "class": cont_id}
                )
        logger.info("Number of images: {}".format(len(self._imdb)))
        logger.info("Number of classes: {}".format(len(self._class_ids)))

    def _prepare_im(self, im):
        """Prepares the image for network input."""
        # Train and test setups differ
        if self._split == "train":
            # Scale and aspect ratio
            im = transforms.random_sized_crop(
                im=im, size=cfg.TRAIN.IM_SIZE, area_frac=0.08
            )
            # Horizontal flip
            im = transforms.horizontal_flip(im=im, p=0.5, order="HWC")
        else:
            # Scale and center crop
            im = transforms.scale(cfg.TEST.IM_SIZE, im)
            im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
        # HWC -> CHW
        im = im.transpose([2, 0, 1])
        # [0, 255] -> [0, 1]
        im = im / 255.0
        # PCA jitter
        if self._split == "train":
            im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS)
        # Color normalization
        im = transforms.color_norm(im, _MEAN, _SD)
        return im

    def __getitem__(self, index):
        # Load the image
        im = cv2.imread(self._imdb[index]["im_path"])
        im = im.astype(np.float32, copy=False)
        # Prepare the image for training / testing
        im = self._prepare_im(im)
        # Retrieve the label
        label = self._imdb[index]["class"]
        return im, label

    def __len__(self):
        return len(self._imdb)
pycls/datasets/loader.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Data loader."""

import pycls.datasets.paths as dp
import torch
from pycls.core.config import cfg
from pycls.datasets.cifar10 import Cifar10
from pycls.datasets.imagenet import ImageNet
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler


# Supported datasets
_DATASET_CATALOG = {"cifar10": Cifar10, "imagenet": ImageNet}


def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
    """Constructs the data loader for the given dataset."""
    assert dataset_name in _DATASET_CATALOG.keys(), "Dataset '{}' not supported".format(
        dataset_name
    )
    assert dp.has_data_path(dataset_name), "Dataset '{}' has no data path".format(
        dataset_name
    )
    # Retrieve the data path for the dataset
    data_path = dp.get_data_path(dataset_name)
    # Construct the dataset
    dataset = _DATASET_CATALOG[dataset_name](data_path, split)
    # Create a sampler for multi-process training
    sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
    # Create a loader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(False if sampler else shuffle),
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=drop_last,
    )
    return loader


def construct_train_loader():
    """Train loader wrapper."""
    return _construct_loader(
        dataset_name=cfg.TRAIN.DATASET,
        split=cfg.TRAIN.SPLIT,
        batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
        shuffle=True,
        drop_last=True,
    )


def construct_test_loader():
    """Test loader wrapper."""
    return _construct_loader(
        dataset_name=cfg.TEST.DATASET,
        split=cfg.TEST.SPLIT,
        batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS),
        shuffle=False,
        drop_last=False,
    )


def shuffle(loader, cur_epoch):
    """Shuffles the data."""
    assert isinstance(
        loader.sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(loader.sampler))
    # RandomSampler handles shuffling automatically
    if isinstance(loader.sampler, DistributedSampler):
        # DistributedSampler shuffles data based on epoch
        loader.sampler.set_epoch(cur_epoch)
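A sketch of the intended call pattern; the epoch-level `shuffle` call matters only in the multi-GPU `DistributedSampler` case, where it re-seeds the sampler so each epoch sees a different ordering:

```python
train_loader = construct_train_loader()
for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
    shuffle(train_loader, cur_epoch)  # no-op for RandomSampler
    for inputs, labels in train_loader:
        pass  # training step elided
```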
pycls/datasets/paths.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Dataset paths."""

import os


# Default data directory (/path/pycls/pycls/datasets/data)
_DEF_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

# Data paths
_paths = {
    "cifar10": _DEF_DATA_DIR + "/cifar10",
    "imagenet": _DEF_DATA_DIR + "/imagenet",
}


def has_data_path(dataset_name):
    """Determines if the dataset has a data path."""
    return dataset_name in _paths.keys()


def get_data_path(dataset_name):
    """Retrieves data path for the dataset."""
    return _paths[dataset_name]


def register_path(name, path):
    """Registers a dataset path dynamically."""
    _paths[name] = path
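Datasets beyond the two defaults can be attached through `register_path`; a hypothetical example (the name and path below are illustrative, not part of this commit):

```python
# Hypothetical dataset name and on-disk location.
register_path("danbooru2019", "/data/danbooru2019")
assert has_data_path("danbooru2019")
print(get_data_path("danbooru2019"))  # /data/danbooru2019
```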
pycls/datasets/transforms.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Image transformations."""

import math

import cv2
import numpy as np


def color_norm(im, mean, std):
    """Performs per-channel normalization (CHW format)."""
    for i in range(im.shape[0]):
        im[i] = im[i] - mean[i]
        im[i] = im[i] / std[i]
    return im


def zero_pad(im, pad_size):
    """Performs zero padding (CHW format)."""
    pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size))
    return np.pad(im, pad_width, mode="constant")


def horizontal_flip(im, p, order="CHW"):
    """Performs horizontal flip (CHW or HWC format)."""
    assert order in ["CHW", "HWC"]
    if np.random.uniform() < p:
        if order == "CHW":
            im = im[:, :, ::-1]
        else:
            im = im[:, ::-1, :]
    return im


def random_crop(im, size, pad_size=0):
    """Performs random crop (CHW format)."""
    if pad_size > 0:
        im = zero_pad(im=im, pad_size=pad_size)
    h, w = im.shape[1:]
    y = np.random.randint(0, h - size)
    x = np.random.randint(0, w - size)
    im_crop = im[:, y : (y + size), x : (x + size)]
    assert im_crop.shape[1:] == (size, size)
    return im_crop


def scale(size, im):
    """Performs scaling (HWC format)."""
    h, w = im.shape[:2]
    if (w <= h and w == size) or (h <= w and h == size):
        return im
    h_new, w_new = size, size
    if w < h:
        h_new = int(math.floor((float(h) / w) * size))
    else:
        w_new = int(math.floor((float(w) / h) * size))
    im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
    return im.astype(np.float32)


def center_crop(size, im):
    """Performs center cropping (HWC format)."""
    h, w = im.shape[:2]
    y = int(math.ceil((h - size) / 2))
    x = int(math.ceil((w - size) / 2))
    im_crop = im[y : (y + size), x : (x + size), :]
    assert im_crop.shape[:2] == (size, size)
    return im_crop


def random_sized_crop(im, size, area_frac=0.08, max_iter=10):
    """Performs Inception-style cropping (HWC format)."""
    h, w = im.shape[:2]
    area = h * w
    for _ in range(max_iter):
        target_area = np.random.uniform(area_frac, 1.0) * area
        aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
        w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio)))
        h_crop = int(round(math.sqrt(float(target_area) / aspect_ratio)))
        if np.random.uniform() < 0.5:
            w_crop, h_crop = h_crop, w_crop
        if h_crop <= h and w_crop <= w:
            y = 0 if h_crop == h else np.random.randint(0, h - h_crop)
            x = 0 if w_crop == w else np.random.randint(0, w - w_crop)
            im_crop = im[y : (y + h_crop), x : (x + w_crop), :]
            assert im_crop.shape[:2] == (h_crop, w_crop)
            im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR)
            return im_crop.astype(np.float32)
    return center_crop(size, scale(size, im))


def lighting(im, alpha_std, eig_val, eig_vec):
    """Performs AlexNet-style PCA jitter (CHW format)."""
    if alpha_std == 0:
        return im
    alpha = np.random.normal(0, alpha_std, size=(1, 3))
    rgb = np.sum(
        eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1
    )
    for i in range(im.shape[0]):
        im[i] = im[i] + rgb[2 - i]
    return im
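As a quick sanity check, a sketch chaining the HWC-side transforms the way the ImageNet pipeline above does (a random array stands in for a decoded BGR image; assumes these functions are in scope, e.g. via `import pycls.datasets.transforms`):

```python
import numpy as np

im = np.random.uniform(0, 255, size=(600, 800, 3)).astype(np.float32)  # HWC, BGR
im = random_sized_crop(im=im, size=512, area_frac=0.08)  # Inception-style crop
im = horizontal_flip(im=im, p=0.5, order="HWC")
im = im.transpose([2, 0, 1]) / 255.0                     # HWC -> CHW, [0, 255] -> [0, 1]
im = color_norm(im, [0.406, 0.456, 0.485], [0.225, 0.224, 0.229])
print(im.shape)  # (3, 512, 512)
```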
pycls/models/anynet.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""AnyNet models."""

import pycls.utils.logging as lu
import pycls.utils.net as nu
import torch.nn as nn
from pycls.core.config import cfg


logger = lu.get_logger(__name__)


def get_stem_fun(stem_type):
    """Retrieves the stem function by name."""
    stem_funs = {
        "res_stem_cifar": ResStemCifar,
        "res_stem_in": ResStemIN,
        "simple_stem_in": SimpleStemIN,
    }
    assert stem_type in stem_funs.keys(), "Stem type '{}' not supported".format(
        stem_type
    )
    return stem_funs[stem_type]


def get_block_fun(block_type):
    """Retrieves the block function by name."""
    block_funs = {
        "vanilla_block": VanillaBlock,
        "res_basic_block": ResBasicBlock,
        "res_bottleneck_block": ResBottleneckBlock,
    }
    assert block_type in block_funs.keys(), "Block type '{}' not supported".format(
        block_type
    )
    return block_funs[block_type]


class AnyHead(nn.Module):
    """AnyNet head."""

    def __init__(self, w_in, nc):
        super(AnyHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class VanillaBlock(nn.Module):
    """Vanilla block: [3x3 conv, BN, ReLU] x2"""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        assert (
            bm is None and gw is None and se_r is None
        ), "Vanilla block does not support bm, gw, and se_r options"
        super(VanillaBlock, self).__init__()
        self._construct(w_in, w_out, stride)

    def _construct(self, w_in, w_out, stride):
        # 3x3, BN, ReLU
        self.a = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 3x3, BN, ReLU
        self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class BasicTransform(nn.Module):
    """Basic transformation: [3x3 conv, BN, ReLU] x2"""

    def __init__(self, w_in, w_out, stride):
        super(BasicTransform, self).__init__()
        self._construct(w_in, w_out, stride)

    def _construct(self, w_in, w_out, stride):
        # 3x3, BN, ReLU
        self.a = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 3x3, BN
        self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class ResBasicBlock(nn.Module):
    """Residual basic block: x + F(x), F = basic transform"""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        assert (
            bm is None and gw is None and se_r is None
        ), "Basic transform does not support bm, gw, and se_r options"
        super(ResBasicBlock, self).__init__()
        self._construct(w_in, w_out, stride)

    def _add_skip_proj(self, w_in, w_out, stride):
        self.proj = nn.Conv2d(
            w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

    def _construct(self, w_in, w_out, stride):
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self._add_skip_proj(w_in, w_out, stride)
        self.f = BasicTransform(w_in, w_out, stride)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block"""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self._construct(w_in, w_se)

    def _construct(self, w_in, w_se):
        # AvgPool
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # FC, Activation, FC, Sigmoid
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
            nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
            nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))


class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, 3x3, 1x1"""

    def __init__(self, w_in, w_out, stride, bm, gw, se_r):
        super(BottleneckTransform, self).__init__()
        self._construct(w_in, w_out, stride, bm, gw, se_r)

    def _construct(self, w_in, w_out, stride, bm, gw, se_r):
        # Compute the bottleneck width
        w_b = int(round(w_out * bm))
        # Compute the number of groups
        num_gs = w_b // gw
        # 1x1, BN, ReLU
        self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 3x3, BN, ReLU
        self.b = nn.Conv2d(
            w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False
        )
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # Squeeze-and-Excitation (SE)
        if se_r:
            w_se = int(round(w_in * se_r))
            self.se = SE(w_b, w_se)
        # 1x1, BN
        self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform"""

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        super(ResBottleneckBlock, self).__init__()
        self._construct(w_in, w_out, stride, bm, gw, se_r)

    def _add_skip_proj(self, w_in, w_out, stride):
        self.proj = nn.Conv2d(
            w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

    def _construct(self, w_in, w_out, stride, bm, gw, se_r):
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self._add_skip_proj(w_in, w_out, stride)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        self._construct(w_in, w_out)

    def _construct(self, w_in, w_out):
        # 3x3, BN, ReLU
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        self._construct(w_in, w_out)

    def _construct(self, w_in, w_out):
        # 7x7, BN, ReLU, maxpool
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet."""

    def __init__(self, in_w, out_w):
        super(SimpleStemIN, self).__init__()
        self._construct(in_w, out_w)

    def _construct(self, in_w, out_w):
        # 3x3, BN, ReLU
        self.conv = nn.Conv2d(
            in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(out_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class AnyStage(nn.Module):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        super(AnyStage, self).__init__()
        self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r)

    def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        # Construct the blocks
        for i in range(d):
            # Stride and w_in apply to the first block of the stage
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            # Construct the block
            self.add_module(
                "b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bm, gw, se_r)
            )

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x


class AnyNet(nn.Module):
    """AnyNet model."""

    def __init__(self, **kwargs):
        super(AnyNet, self).__init__()
        if kwargs:
            self._construct(
                stem_type=kwargs["stem_type"],
                stem_w=kwargs["stem_w"],
                block_type=kwargs["block_type"],
                ds=kwargs["ds"],
                ws=kwargs["ws"],
                ss=kwargs["ss"],
                bms=kwargs["bms"],
                gws=kwargs["gws"],
                se_r=kwargs["se_r"],
                nc=kwargs["nc"],
            )
        else:
            self._construct(
                stem_type=cfg.ANYNET.STEM_TYPE,
                stem_w=cfg.ANYNET.STEM_W,
                block_type=cfg.ANYNET.BLOCK_TYPE,
                ds=cfg.ANYNET.DEPTHS,
                ws=cfg.ANYNET.WIDTHS,
                ss=cfg.ANYNET.STRIDES,
                bms=cfg.ANYNET.BOT_MULS,
                gws=cfg.ANYNET.GROUP_WS,
                se_r=cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
                nc=cfg.MODEL.NUM_CLASSES,
            )
        self.apply(nu.init_weights)

    def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        logger.info("Constructing AnyNet: ds={}, ws={}".format(ds, ws))
        # Generate dummy bot muls and gs for models that do not use them
        bms = bms if bms else [1.0 for _d in ds]
        gws = gws if gws else [1 for _d in ds]
        # Group params by stage
        stage_params = list(zip(ds, ws, ss, bms, gws))
        # Construct the stem
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w)
        # Construct the stages
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            self.add_module(
                "s{}".format(i + 1), AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r)
            )
            prev_w = w
        # Construct the head
        self.head = AnyHead(w_in=prev_w, nc=nc)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x
pycls/models/effnet.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""EfficientNet models."""

import pycls.utils.logging as logging
import pycls.utils.net as nu
import torch
import torch.nn as nn
from pycls.core.config import cfg


logger = logging.get_logger(__name__)


class EffHead(nn.Module):
    """EfficientNet head."""

    def __init__(self, w_in, w_out, nc):
        super(EffHead, self).__init__()
        self._construct(w_in, w_out, nc)

    def _construct(self, w_in, w_out, nc):
        # 1x1, BN, Swish
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.conv_swish = Swish()
        # AvgPool
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Dropout
        if cfg.EN.DROPOUT_RATIO > 0.0:
            self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
        # FC
        self.fc = nn.Linear(w_out, nc, bias=True)

    def forward(self, x):
        x = self.conv_swish(self.conv_bn(self.conv(x)))
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x) if hasattr(self, "dropout") else x
        x = self.fc(x)
        return x


class Swish(nn.Module):
    """Swish activation function: x * sigmoid(x)"""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block w/ Swish."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self._construct(w_in, w_se)

    def _construct(self, w_in, w_se):
        # AvgPool
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # FC, Swish, FC, Sigmoid
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
            Swish(),
            nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))


class MBConv(nn.Module):
    """Mobile inverted bottleneck block w/ SE (MBConv)."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
        super(MBConv, self).__init__()
        self._construct(w_in, exp_r, kernel, stride, se_r, w_out)

    def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out):
        # Expansion ratio is wrt the input width
        self.exp = None
        w_exp = int(w_in * exp_r)
        # Include exp ops only if the exp ratio is different from 1
        if w_exp != w_in:
            # 1x1, BN, Swish
            self.exp = nn.Conv2d(
                w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False
            )
            self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
            self.exp_swish = Swish()
        # 3x3 dwise, BN, Swish
        self.dwise = nn.Conv2d(
            w_exp,
            w_exp,
            kernel_size=kernel,
            stride=stride,
            groups=w_exp,
            bias=False,
            # Hacky padding to preserve res (supports only 3x3 and 5x5)
            padding=(1 if kernel == 3 else 2),
        )
        self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.dwise_swish = Swish()
        # Squeeze-and-Excitation (SE)
        w_se = int(w_in * se_r)
        self.se = SE(w_exp, w_se)
        # 1x1, BN
        self.lin_proj = nn.Conv2d(
            w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = (stride == 1) and (w_in == w_out)

    def forward(self, x):
        f_x = x
        # Expansion
        if self.exp:
            f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
        # Depthwise
        f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
        # SE
        f_x = self.se(f_x)
        # Linear projection
        f_x = self.lin_proj_bn(self.lin_proj(f_x))
        # Skip connection
        if self.has_skip:
            # Drop connect
            if self.training and cfg.EN.DC_RATIO > 0.0:
                f_x = nu.drop_connect(f_x, cfg.EN.DC_RATIO)
            f_x = x + f_x
        return f_x


class EffStage(nn.Module):
    """EfficientNet stage."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
        super(EffStage, self).__init__()
        self._construct(w_in, exp_r, kernel, stride, se_r, w_out, d)

    def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
        # Construct the blocks
        for i in range(d):
            # Stride and input width apply to the first block of the stage
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            # Construct the block
            self.add_module(
                "b{}".format(i + 1),
                MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out),
            )

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x


class StemIN(nn.Module):
    """EfficientNet stem for ImageNet."""

    def __init__(self, w_in, w_out):
        super(StemIN, self).__init__()
        self._construct(w_in, w_out)

    def _construct(self, w_in, w_out):
        # 3x3, BN, Swish
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.swish = Swish()

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class EffNet(nn.Module):
    """EfficientNet model."""

    def __init__(self):
        assert cfg.TRAIN.DATASET in [
            "imagenet"
        ], "Training on {} is not supported".format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in [
            "imagenet"
        ], "Testing on {} is not supported".format(cfg.TEST.DATASET)
        super(EffNet, self).__init__()
        self._construct(
            stem_w=cfg.EN.STEM_W,
            ds=cfg.EN.DEPTHS,
            ws=cfg.EN.WIDTHS,
            exp_rs=cfg.EN.EXP_RATIOS,
            se_r=cfg.EN.SE_R,
            ss=cfg.EN.STRIDES,
            ks=cfg.EN.KERNELS,
            head_w=cfg.EN.HEAD_W,
            nc=cfg.MODEL.NUM_CLASSES,
        )
        self.apply(nu.init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
        # Group params by stage
        stage_params = list(zip(ds, ws, exp_rs, ss, ks))
        logger.info("Constructing: EfficientNet-{}".format(stage_params))
        # Construct the stem
        self.stem = StemIN(3, stem_w)
        prev_w = stem_w
        # Construct the stages
        for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
            self.add_module(
                "s{}".format(i + 1), EffStage(prev_w, exp_r, kernel, stride, se_r, w, d)
            )
            prev_w = w
        # Construct the head
        self.head = EffHead(prev_w, head_w, nc)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x
pycls/models/regnet.py:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""RegNet models."""

import numpy as np
import pycls.utils.logging as lu
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet


logger = lu.get_logger(__name__)


def quantize_float(f, q):
    """Converts a float to closest non-zero int divisible by q."""
    return int(round(f / q) * q)


def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts the compatibility of widths and groups."""
    ws_bot = [int(w * b) for w, b in zip(ws, bms)]
    gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
    ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
    ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
    return ws, gs


def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per block values."""
    ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
    ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
    s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
    s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
    return s_ws, s_ds


def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per block ws from RegNet parameters."""
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    ws_cont = np.arange(d) * w_a + w_0
    ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
    ws = w_0 * np.power(w_m, ks)
    ws = np.round(np.divide(ws, q)) * q
    num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
    ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
    return ws, num_stages, max_stage, ws_cont


class RegNet(AnyNet):
    """RegNet model."""

    def __init__(self):
        # Generate RegNet ws per block
        b_ws, num_s, _, _ = generate_regnet(
            cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
        )
        # Convert to per stage format
        ws, ds = get_stages_from_blocks(b_ws, b_ws)
        # Generate group widths and bot muls
        gws = [cfg.REGNET.GROUP_W for _ in range(num_s)]
        bms = [cfg.REGNET.BOT_MUL for _ in range(num_s)]
        # Adjust the compatibility of ws and gws
        ws, gws = adjust_ws_gs_comp(ws, bms, gws)
        # Use the same stride for each stage
        ss = [cfg.REGNET.STRIDE for _ in range(num_s)]
        # Use SE for RegNetY
        se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None
        # Construct the model
        kwargs = {
            "stem_type": cfg.REGNET.STEM_TYPE,
            "stem_w": cfg.REGNET.STEM_W,
            "block_type": cfg.REGNET.BLOCK_TYPE,
            "ss": ss,
            "ds": ds,
            "ws": ws,
            "bms": bms,
            "gws": gws,
            "se_r": se_r,
            "nc": cfg.MODEL.NUM_CLASSES,
        }
        super(RegNet, self).__init__(**kwargs)
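To see concretely what the YAML above instantiates, `generate_regnet` can be run with those parameters; a small sketch (printed values not reproduced here):

```python
# Width schedule for the config above (WA=76.82, W0=192, WM=2.19, DEPTH=17)
b_ws, num_s, _, _ = generate_regnet(w_a=76.82, w_0=192, w_m=2.19, d=17)
print(b_ws, num_s)                               # quantized per-block widths, stage count
s_ws, s_ds = get_stages_from_blocks(b_ws, b_ws)
print(s_ws, s_ds)                                # per-stage widths and depths
```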
pycls/models/resnet.py:
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""ResNe(X)t models.""" | ||
|
||
import pycls.utils.logging as lu | ||
import pycls.utils.net as nu | ||
import torch.nn as nn | ||
from pycls.core.config import cfg | ||
|
||
|
||
logger = lu.get_logger(__name__) | ||
|
||
|
||
# Stage depths for ImageNet models | ||
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} | ||
|
||
|
||
def get_trans_fun(name): | ||
"""Retrieves the transformation function by name.""" | ||
trans_funs = { | ||
"basic_transform": BasicTransform, | ||
"bottleneck_transform": BottleneckTransform, | ||
} | ||
assert ( | ||
name in trans_funs.keys() | ||
), "Transformation function '{}' not supported".format(name) | ||
return trans_funs[name] | ||
|
||
|
||
class ResHead(nn.Module): | ||
"""ResNet head.""" | ||
|
||
def __init__(self, w_in, nc): | ||
super(ResHead, self).__init__() | ||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||
self.fc = nn.Linear(w_in, nc, bias=True) | ||
|
||
def forward(self, x): | ||
x = self.avg_pool(x) | ||
x = x.view(x.size(0), -1) | ||
x = self.fc(x) | ||
return x | ||
|
||
|
||
class BasicTransform(nn.Module): | ||
"""Basic transformation: 3x3, 3x3""" | ||
|
||
def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1): | ||
assert ( | ||
w_b is None and num_gs == 1 | ||
), "Basic transform does not support w_b and num_gs options" | ||
super(BasicTransform, self).__init__() | ||
self._construct(w_in, w_out, stride) | ||
|
||
def _construct(self, w_in, w_out, stride): | ||
# 3x3, BN, ReLU | ||
self.a = nn.Conv2d( | ||
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False | ||
) | ||
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||
# 3x3, BN | ||
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False) | ||
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.b_bn.final_bn = True | ||
|
||
def forward(self, x): | ||
for layer in self.children(): | ||
x = layer(x) | ||
return x | ||
|
||
|
||
class BottleneckTransform(nn.Module): | ||
"""Bottleneck transformation: 1x1, 3x3, 1x1""" | ||
|
||
def __init__(self, w_in, w_out, stride, w_b, num_gs): | ||
super(BottleneckTransform, self).__init__() | ||
self._construct(w_in, w_out, stride, w_b, num_gs) | ||
|
||
def _construct(self, w_in, w_out, stride, w_b, num_gs): | ||
# MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3 | ||
(str1x1, str3x3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) | ||
# 1x1, BN, ReLU | ||
self.a = nn.Conv2d( | ||
w_in, w_b, kernel_size=1, stride=str1x1, padding=0, bias=False | ||
) | ||
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||
# 3x3, BN, ReLU | ||
self.b = nn.Conv2d( | ||
w_b, w_b, kernel_size=3, stride=str3x3, padding=1, groups=num_gs, bias=False | ||
) | ||
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||
# 1x1, BN | ||
self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False) | ||
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.c_bn.final_bn = True | ||
|
||
def forward(self, x): | ||
for layer in self.children(): | ||
x = layer(x) | ||
return x | ||
|
||
|
||
class ResBlock(nn.Module): | ||
"""Residual block: x + F(x)""" | ||
|
||
def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1): | ||
super(ResBlock, self).__init__() | ||
self._construct(w_in, w_out, stride, trans_fun, w_b, num_gs) | ||
|
||
def _add_skip_proj(self, w_in, w_out, stride): | ||
self.proj = nn.Conv2d( | ||
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False | ||
) | ||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
|
||
def _construct(self, w_in, w_out, stride, trans_fun, w_b, num_gs): | ||
# Use skip connection with projection if shape changes | ||
self.proj_block = (w_in != w_out) or (stride != 1) | ||
if self.proj_block: | ||
self._add_skip_proj(w_in, w_out, stride) | ||
self.f = trans_fun(w_in, w_out, stride, w_b, num_gs) | ||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||
|
||
def forward(self, x): | ||
if self.proj_block: | ||
x = self.bn(self.proj(x)) + self.f(x) | ||
else: | ||
x = x + self.f(x) | ||
x = self.relu(x) | ||
return x | ||
|
||
|
||
class ResStage(nn.Module): | ||
"""Stage of ResNet.""" | ||
|
||
def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1): | ||
super(ResStage, self).__init__() | ||
self._construct(w_in, w_out, stride, d, w_b, num_gs) | ||
|
||
def _construct(self, w_in, w_out, stride, d, w_b, num_gs): | ||
# Construct the blocks | ||
for i in range(d): | ||
# Stride and w_in apply to the first block of the stage | ||
b_stride = stride if i == 0 else 1 | ||
b_w_in = w_in if i == 0 else w_out | ||
# Retrieve the transformation function | ||
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN) | ||
# Construct the block | ||
res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs) | ||
self.add_module("b{}".format(i + 1), res_block) | ||
|
||
def forward(self, x): | ||
for block in self.children(): | ||
x = block(x) | ||
return x | ||
|
||
|
||
class ResStem(nn.Module): | ||
"""Stem of ResNet.""" | ||
|
||
def __init__(self, w_in, w_out): | ||
assert ( | ||
cfg.TRAIN.DATASET == cfg.TEST.DATASET | ||
), "Train and test dataset must be the same for now" | ||
super(ResStem, self).__init__() | ||
if "cifar" in cfg.TRAIN.DATASET: | ||
self._construct_cifar(w_in, w_out) | ||
else: | ||
self._construct_imagenet(w_in, w_out) | ||
|
||
def _construct_cifar(self, w_in, w_out): | ||
# 3x3, BN, ReLU | ||
self.conv = nn.Conv2d( | ||
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False | ||
) | ||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||
|
||
def _construct_imagenet(self, w_in, w_out): | ||
# 7x7, BN, ReLU, maxpool | ||
self.conv = nn.Conv2d( | ||
w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False | ||
) | ||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
|
||
def forward(self, x): | ||
for layer in self.children(): | ||
x = layer(x) | ||
return x | ||
|
||
|
||
class ResNet(nn.Module): | ||
"""ResNet model.""" | ||
|
||
def __init__(self): | ||
assert cfg.TRAIN.DATASET in [ | ||
"cifar10", | ||
"imagenet", | ||
], "Training ResNet on {} is not supported".format(cfg.TRAIN.DATASET) | ||
assert cfg.TEST.DATASET in [ | ||
"cifar10", | ||
"imagenet", | ||
], "Testing ResNet on {} is not supported".format(cfg.TEST.DATASET) | ||
super(ResNet, self).__init__() | ||
if "cifar" in cfg.TRAIN.DATASET: | ||
self._construct_cifar() | ||
else: | ||
self._construct_imagenet() | ||
self.apply(nu.init_weights) | ||
|
||
def _construct_cifar(self): | ||
assert ( | ||
cfg.MODEL.DEPTH - 2 | ||
) % 6 == 0, "Model depth should be of the format 6n + 2 for cifar" | ||
logger.info("Constructing: ResNet-{}".format(cfg.MODEL.DEPTH)) | ||
# Each stage has the same number of blocks for cifar | ||
d = int((cfg.MODEL.DEPTH - 2) / 6) | ||
# Stem: (N, 3, 32, 32) -> (N, 16, 32, 32) | ||
self.stem = ResStem(w_in=3, w_out=16) | ||
# Stage 1: (N, 16, 32, 32) -> (N, 16, 32, 32) | ||
self.s1 = ResStage(w_in=16, w_out=16, stride=1, d=d) | ||
# Stage 2: (N, 16, 32, 32) -> (N, 32, 16, 16) | ||
self.s2 = ResStage(w_in=16, w_out=32, stride=2, d=d) | ||
# Stage 3: (N, 32, 16, 16) -> (N, 64, 8, 8) | ||
self.s3 = ResStage(w_in=32, w_out=64, stride=2, d=d) | ||
# Head: (N, 64, 8, 8) -> (N, num_classes) | ||
self.head = ResHead(w_in=64, nc=cfg.MODEL.NUM_CLASSES) | ||
|
||
def _construct_imagenet(self): | ||
logger.info( | ||
"Constructing: ResNe(X)t-{}-{}x{}, {}".format( | ||
cfg.MODEL.DEPTH, | ||
cfg.RESNET.NUM_GROUPS, | ||
cfg.RESNET.WIDTH_PER_GROUP, | ||
cfg.RESNET.TRANS_FUN, | ||
) | ||
) | ||
# Retrieve the number of blocks per stage | ||
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] | ||
# Compute the initial bottleneck width | ||
num_gs = cfg.RESNET.NUM_GROUPS | ||
w_b = cfg.RESNET.WIDTH_PER_GROUP * num_gs | ||
# Stem: (N, 3, 224, 224) -> (N, 64, 56, 56) | ||
self.stem = ResStem(w_in=3, w_out=64) | ||
# Stage 1: (N, 64, 56, 56) -> (N, 256, 56, 56) | ||
self.s1 = ResStage(w_in=64, w_out=256, stride=1, d=d1, w_b=w_b, num_gs=num_gs) | ||
# Stage 2: (N, 256, 56, 56) -> (N, 512, 28, 28) | ||
self.s2 = ResStage( | ||
w_in=256, w_out=512, stride=2, d=d2, w_b=w_b * 2, num_gs=num_gs | ||
) | ||
# Stage 3: (N, 512, 28, 28) -> (N, 1024, 14, 14) | ||
self.s3 = ResStage( | ||
w_in=512, w_out=1024, stride=2, d=d3, w_b=w_b * 4, num_gs=num_gs | ||
) | ||
# Stage 4: (N, 1024, 14, 14) -> (N, 2048, 7, 7) | ||
self.s4 = ResStage( | ||
w_in=1024, w_out=2048, stride=2, d=d4, w_b=w_b * 8, num_gs=num_gs | ||
) | ||
# Head: (N, 2048, 7, 7) -> (N, num_classes) | ||
self.head = ResHead(w_in=2048, nc=cfg.MODEL.NUM_CLASSES) | ||
|
||
def forward(self, x): | ||
for module in self.children(): | ||
x = module(x) | ||
return x |
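
For orientation, here is a minimal, self-contained sketch of the residual pattern these classes implement (x + F(x), with a 1x1 projection on the shortcut when the shape changes). It hard-codes hyperparameters in place of the cfg globals and uses a plain two-conv transform where the real model uses the bottleneck:

```python
import torch
import torch.nn as nn

class TinyResBlock(nn.Module):
    """x + F(x), projecting the shortcut when the shape changes."""

    def __init__(self, w_in, w_out, stride):
        super().__init__()
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, bias=False)
            self.bn = nn.BatchNorm2d(w_out)
        self.f = nn.Sequential(  # basic transform standing in for the bottleneck
            nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(w_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(w_out),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

x = torch.randn(2, 16, 32, 32)
print(TinyResBlock(16, 32, stride=2)(x).shape)  # torch.Size([2, 32, 16, 16])
```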
@@ -0,0 +1,89 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Functions for benchmarking networks.""" | ||
|
||
import pycls.utils.logging as lu | ||
import torch | ||
from pycls.core.config import cfg | ||
from pycls.utils.timer import Timer | ||
|
||
|
||
@torch.no_grad() | ||
def compute_fw_test_time(model, inputs): | ||
"""Computes forward test time (no grad, eval mode).""" | ||
# Use eval mode | ||
model.eval() | ||
# Warm up the caches | ||
for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER): | ||
model(inputs) | ||
# Make sure warmup kernels completed | ||
torch.cuda.synchronize() | ||
# Compute precise forward pass time | ||
timer = Timer() | ||
for _cur_iter in range(cfg.PREC_TIME.NUM_ITER): | ||
timer.tic() | ||
model(inputs) | ||
torch.cuda.synchronize() | ||
timer.toc() | ||
# Make sure forward kernels completed | ||
torch.cuda.synchronize() | ||
return timer.average_time | ||
|
||
|
||
def compute_fw_bw_time(model, loss_fun, inputs, labels): | ||
"""Computes forward backward time.""" | ||
# Use train mode | ||
model.train() | ||
# Warm up the caches | ||
for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER): | ||
preds = model(inputs) | ||
loss = loss_fun(preds, labels) | ||
loss.backward() | ||
# Make sure warmup kernels completed | ||
torch.cuda.synchronize() | ||
# Compute precise forward backward pass time | ||
fw_timer = Timer() | ||
bw_timer = Timer() | ||
for _cur_iter in range(cfg.PREC_TIME.NUM_ITER): | ||
# Forward | ||
fw_timer.tic() | ||
preds = model(inputs) | ||
loss = loss_fun(preds, labels) | ||
torch.cuda.synchronize() | ||
fw_timer.toc() | ||
# Backward | ||
bw_timer.tic() | ||
loss.backward() | ||
torch.cuda.synchronize() | ||
bw_timer.toc() | ||
# Make sure forward backward kernels completed | ||
torch.cuda.synchronize() | ||
return fw_timer.average_time, bw_timer.average_time | ||
|
||
|
||
def compute_precise_time(model, loss_fun): | ||
"""Computes precise time.""" | ||
# Generate a dummy mini-batch | ||
im_size = cfg.TRAIN.IM_SIZE | ||
inputs = torch.rand(cfg.PREC_TIME.BATCH_SIZE, 3, im_size, im_size) | ||
labels = torch.zeros(cfg.PREC_TIME.BATCH_SIZE, dtype=torch.int64) | ||
# Copy the data to the GPU | ||
inputs = inputs.cuda(non_blocking=False) | ||
labels = labels.cuda(non_blocking=False) | ||
# Compute precise time | ||
fw_test_time = compute_fw_test_time(model, inputs) | ||
fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels) | ||
# Log precise time | ||
lu.log_json_stats( | ||
{ | ||
"prec_test_fw_time": fw_test_time, | ||
"prec_train_fw_time": fw_time, | ||
"prec_train_bw_time": bw_time, | ||
"prec_train_fw_bw_time": fw_time + bw_time, | ||
} | ||
) |
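
The key idea in these benchmark helpers is the warmup-then-synchronize pattern: CUDA kernels launch asynchronously, so the clock must only be read after `torch.cuda.synchronize()`. A stripped-down sketch of the same pattern, assuming a CUDA device and using `time.perf_counter` in place of `Timer` (the toy model is made up):

```python
import time
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)
).cuda().eval()
inputs = torch.rand(8, 3, 224, 224).cuda()

with torch.no_grad():
    for _ in range(3):                 # warm up kernels and caches
        model(inputs)
    torch.cuda.synchronize()           # let warmup kernels finish
    start = time.perf_counter()
    for _ in range(10):
        model(inputs)
    torch.cuda.synchronize()           # drain the queue before reading the clock
    print("{:.4f} s/iter".format((time.perf_counter() - start) / 10))
```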
@@ -0,0 +1,91 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Functions that handle saving and loading of checkpoints.""" | ||
|
||
import os | ||
|
||
import pycls.utils.distributed as du | ||
import torch | ||
from pycls.core.config import cfg | ||
|
||
|
||
# Common prefix for checkpoint file names | ||
_NAME_PREFIX = "model_epoch_" | ||
# Checkpoints directory name | ||
_DIR_NAME = "checkpoints" | ||
|
||
|
||
def get_checkpoint_dir(): | ||
"""Retrieves the location for storing checkpoints.""" | ||
return os.path.join(cfg.OUT_DIR, _DIR_NAME) | ||
|
||
|
||
def get_checkpoint(epoch): | ||
"""Retrieves the path to a checkpoint file.""" | ||
name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) | ||
return os.path.join(get_checkpoint_dir(), name) | ||
|
||
|
||
def get_last_checkpoint(): | ||
"""Retrieves the most recent checkpoint (highest epoch number).""" | ||
checkpoint_dir = get_checkpoint_dir() | ||
# Checkpoint file names are in lexicographic order | ||
checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] | ||
last_checkpoint_name = sorted(checkpoints)[-1] | ||
return os.path.join(checkpoint_dir, last_checkpoint_name) | ||
|
||
|
||
def has_checkpoint(): | ||
"""Determines if there are checkpoints available.""" | ||
checkpoint_dir = get_checkpoint_dir() | ||
if not os.path.exists(checkpoint_dir): | ||
return False | ||
return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) | ||
|
||
|
||
def is_checkpoint_epoch(cur_epoch): | ||
"""Determines if a checkpoint should be saved on current epoch.""" | ||
return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0 | ||
|
||
|
||
def save_checkpoint(model, optimizer, epoch): | ||
"""Saves a checkpoint.""" | ||
# Save checkpoints only from the master process | ||
if not du.is_master_proc(): | ||
return | ||
# Ensure that the checkpoint dir exists | ||
os.makedirs(get_checkpoint_dir(), exist_ok=True) | ||
# Omit the DDP wrapper in the multi-gpu setting | ||
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() | ||
# Record the state | ||
checkpoint = { | ||
"epoch": epoch, | ||
"model_state": sd, | ||
"optimizer_state": optimizer.state_dict(), | ||
"cfg": cfg.dump(), | ||
} | ||
# Write the checkpoint | ||
checkpoint_file = get_checkpoint(epoch + 1) | ||
torch.save(checkpoint, checkpoint_file) | ||
return checkpoint_file | ||
|
||
|
||
def load_checkpoint(checkpoint_file, model, optimizer=None): | ||
"""Loads the checkpoint from the given file.""" | ||
assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format( | ||
checkpoint_file | ||
) | ||
# Load the checkpoint on CPU to avoid GPU mem spike | ||
checkpoint = torch.load(checkpoint_file, map_location="cpu") | ||
# Account for the DDP wrapper in the multi-gpu setting | ||
ms = model.module if cfg.NUM_GPUS > 1 else model | ||
ms.load_state_dict(checkpoint["model_state"]) | ||
# Load the optimizer state (commonly not done when fine-tuning) | ||
if optimizer: | ||
optimizer.load_state_dict(checkpoint["optimizer_state"]) | ||
return checkpoint["epoch"] |
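
A single-process sketch of the save/load round trip above, without the cfg/DDP machinery (the file name and toy model are illustrative):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save: bundle everything needed to resume into one dict
torch.save(
    {"epoch": 3, "model_state": model.state_dict(),
     "optimizer_state": optimizer.state_dict()},
    "model_epoch_0003.pyth",
)

# Load: restore on CPU first, exactly as load_checkpoint does
checkpoint = torch.load("model_epoch_0003.pyth", map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
print("resuming after epoch", checkpoint["epoch"])
```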
@@ -0,0 +1,61 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Distributed helpers.""" | ||
|
||
import torch | ||
from pycls.core.config import cfg | ||
|
||
|
||
def is_master_proc(): | ||
"""Determines if the current process is the master process. | ||
Master process is responsible for logging, writing and loading checkpoints. | ||
In the multi GPU setting, we assign the master role to the rank 0 process. | ||
When training using a single GPU, there is only one training process, | ||
which is considered the master process. | ||
""" | ||
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 | ||
|
||
|
||
def init_process_group(proc_rank, world_size): | ||
"""Initializes the default process group.""" | ||
# Set the GPU to use | ||
torch.cuda.set_device(proc_rank) | ||
# Initialize the process group | ||
torch.distributed.init_process_group( | ||
backend=cfg.DIST_BACKEND, | ||
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT), | ||
world_size=world_size, | ||
rank=proc_rank, | ||
) | ||
|
||
|
||
def destroy_process_group(): | ||
"""Destroys the default process group.""" | ||
torch.distributed.destroy_process_group() | ||
|
||
|
||
def scaled_all_reduce(tensors): | ||
"""Performs the scaled all_reduce operation on the provided tensors. | ||
The input tensors are modified in-place. Currently supports only the sum | ||
reduction operator. The reduced values are scaled by the inverse size of | ||
the process group (equivalent to cfg.NUM_GPUS). | ||
""" | ||
# Queue the reductions | ||
reductions = [] | ||
for tensor in tensors: | ||
reduction = torch.distributed.all_reduce(tensor, async_op=True) | ||
reductions.append(reduction) | ||
# Wait for reductions to finish | ||
for reduction in reductions: | ||
reduction.wait() | ||
# Scale the results | ||
for tensor in tensors: | ||
tensor.mul_(1.0 / cfg.NUM_GPUS) | ||
return tensors |
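
Since scaled_all_reduce sums each tensor across ranks and then multiplies by 1/NUM_GPUS, every rank ends up holding the mean. A single-process illustration of that arithmetic for a world size of 4:

```python
import torch

world_size = 4
per_rank = [torch.full((2,), float(rank)) for rank in range(world_size)]

total = torch.stack(per_rank).sum(dim=0)  # what all_reduce with SUM produces
mean = total * (1.0 / world_size)         # the 1/NUM_GPUS scaling applied above
print(mean)  # tensor([1.5000, 1.5000])
```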
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Multiprocessing error handler.""" | ||
|
||
import os | ||
import signal | ||
import threading | ||
|
||
|
||
class ChildException(Exception): | ||
"""Wraps an exception from a child process.""" | ||
|
||
def __init__(self, child_trace): | ||
super(ChildException, self).__init__(child_trace) | ||
|
||
|
||
class ErrorHandler(object): | ||
"""Multiprocessing error handler (based on fairseq's). | ||
Listens for errors in child processes and | ||
propagates the tracebacks to the parent process. | ||
""" | ||
|
||
def __init__(self, error_queue): | ||
# Shared error queue | ||
self.error_queue = error_queue | ||
# Children processes sharing the error queue | ||
self.children_pids = [] | ||
# Start a thread listening to errors | ||
self.error_listener = threading.Thread(target=self.listen, daemon=True) | ||
self.error_listener.start() | ||
# Register the signal handler | ||
signal.signal(signal.SIGUSR1, self.signal_handler) | ||
|
||
def add_child(self, pid): | ||
"""Registers a child process.""" | ||
self.children_pids.append(pid) | ||
|
||
def listen(self): | ||
"""Listens for errors in the error queue.""" | ||
# Wait until there is an error in the queue | ||
child_trace = self.error_queue.get() | ||
# Put the error back for the signal handler | ||
self.error_queue.put(child_trace) | ||
# Invoke the signal handler | ||
os.kill(os.getpid(), signal.SIGUSR1) | ||
|
||
def signal_handler(self, _sig_num, _stack_frame): | ||
"""Signal handler.""" | ||
# Kill children processes | ||
for pid in self.children_pids: | ||
os.kill(pid, signal.SIGINT) | ||
# Propagate the error from the child process | ||
raise ChildException(self.error_queue.get()) |
@@ -0,0 +1,90 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""IO utilities (adapted from Detectron)""" | ||
|
||
import logging | ||
import os | ||
import re | ||
import sys | ||
from urllib import request as urlrequest | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls" | ||
|
||
|
||
def cache_url(url_or_file, cache_dir): | ||
"""Download the file specified by the URL to the cache_dir and return the | ||
path to the cached file. If the argument is not a URL, simply return it as | ||
is. | ||
""" | ||
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None | ||
|
||
if not is_url: | ||
return url_or_file | ||
|
||
url = url_or_file | ||
assert url.startswith(_PYCLS_BASE_URL), ( | ||
"pycls only automatically caches URLs in the pycls S3 bucket: {}" | ||
).format(_PYCLS_BASE_URL) | ||
|
||
cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir) | ||
if os.path.exists(cache_file_path): | ||
return cache_file_path | ||
|
||
cache_file_dir = os.path.dirname(cache_file_path) | ||
if not os.path.exists(cache_file_dir): | ||
os.makedirs(cache_file_dir) | ||
|
||
logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) | ||
download_url(url, cache_file_path) | ||
return cache_file_path | ||
|
||
|
||
def _progress_bar(count, total): | ||
"""Report download progress. | ||
Credit: | ||
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 | ||
""" | ||
bar_len = 60 | ||
filled_len = int(round(bar_len * count / float(total))) | ||
|
||
percents = round(100.0 * count / float(total), 1) | ||
bar = "=" * filled_len + "-" * (bar_len - filled_len) | ||
|
||
sys.stdout.write( | ||
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) | ||
) | ||
sys.stdout.flush() | ||
if count >= total: | ||
sys.stdout.write("\n") | ||
|
||
|
||
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): | ||
"""Download url and write it to dst_file_path. | ||
Credit: | ||
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook | ||
""" | ||
req = urlrequest.Request(url) | ||
response = urlrequest.urlopen(req) | ||
total_size = response.info().get("Content-Length").strip() | ||
total_size = int(total_size) | ||
bytes_so_far = 0 | ||
|
||
with open(dst_file_path, "wb") as f: | ||
while True: | ||
chunk = response.read(chunk_size) | ||
bytes_so_far += len(chunk) | ||
if not chunk: | ||
break | ||
if progress_hook: | ||
progress_hook(bytes_so_far, total_size) | ||
f.write(chunk) | ||
|
||
return bytes_so_far |
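
cache_url maps a bucket URL onto a local path simply by swapping the base URL for the cache directory. A worked example of just that string mapping (the file name here is made up):

```python
cache_dir = "/tmp/pycls-cache"
url = "https://dl.fbaipublicfiles.com/pycls/models/example.pyth"  # made-up file
local = url.replace("https://dl.fbaipublicfiles.com/pycls", cache_dir)
print(local)  # /tmp/pycls-cache/models/example.pyth
```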
@@ -0,0 +1,100 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Logging.""" | ||
|
||
import builtins | ||
import decimal | ||
import logging | ||
import os | ||
import sys | ||
|
||
import pycls.utils.distributed as du | ||
import simplejson | ||
from pycls.core.config import cfg | ||
|
||
|
||
# Show filename and line number in logs | ||
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" | ||
|
||
# Log file name (for cfg.LOG_DEST = 'file') | ||
_LOG_FILE = "stdout.log" | ||
|
||
# Printed json stats lines will be tagged w/ this | ||
_TAG = "json_stats: " | ||
|
||
|
||
def _suppress_print(): | ||
"""Suppresses printing from the current process.""" | ||
|
||
def ignore(*_args, **_kwargs):  # swallow any print() args/kwargs (sep=, end=, ...) | ||
pass | ||
|
||
builtins.print = ignore | ||
|
||
|
||
def setup_logging(): | ||
"""Sets up the logging.""" | ||
# Enable logging only for the master process | ||
if du.is_master_proc(): | ||
# Clear the root logger to prevent any existing logging config | ||
# (e.g. set by another module) from messing with our setup | ||
logging.root.handlers = [] | ||
# Construct logging configuration | ||
logging_config = {"level": logging.INFO, "format": _FORMAT} | ||
# Log either to stdout or to a file | ||
if cfg.LOG_DEST == "stdout": | ||
logging_config["stream"] = sys.stdout | ||
else: | ||
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) | ||
# Configure logging | ||
logging.basicConfig(**logging_config) | ||
else: | ||
_suppress_print() | ||
|
||
|
||
def get_logger(name): | ||
"""Retrieves the logger.""" | ||
return logging.getLogger(name) | ||
|
||
|
||
def log_json_stats(stats): | ||
"""Logs json stats.""" | ||
# Decimal + string workaround for having fixed len float vals in logs | ||
stats = { | ||
k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v | ||
for k, v in stats.items() | ||
} | ||
json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) | ||
logger = get_logger(__name__) | ||
logger.info("{:s}{:s}".format(_TAG, json_stats)) | ||
|
||
|
||
def load_json_stats(log_file): | ||
"""Loads json_stats from a single log file.""" | ||
with open(log_file, "r") as f: | ||
lines = f.readlines() | ||
json_lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l] | ||
json_stats = [simplejson.loads(l) for l in json_lines] | ||
return json_stats | ||
|
||
|
||
def parse_json_stats(log, row_type, key): | ||
"""Extract values corresponding to row_type/key out of log.""" | ||
vals = [row[key] for row in log if row["_type"] == row_type and key in row] | ||
if key == "iter" or key == "epoch": | ||
vals = [int(val.split("/")[0]) for val in vals] | ||
return vals | ||
|
||
|
||
def get_log_files(log_dir, name_filter=""): | ||
"""Get all log files in directory containing subdirs of trained models.""" | ||
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n] | ||
files = [os.path.join(log_dir, n, _LOG_FILE) for n in names] | ||
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)] | ||
files, names = zip(*f_n_ps) | ||
return files, names |
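
log_json_stats emits one tagged JSON object per line, and load_json_stats recovers it by splitting on the tag. A round-trip sketch on a single hand-written log line:

```python
import simplejson

_TAG = "json_stats: "
line = '[meters.py: 100]: json_stats: {"_type": "train_epoch", "epoch": "1/100", "loss": 2.301}'

payload = line[line.find(_TAG) + len(_TAG):]
stats = simplejson.loads(payload)
print(stats["_type"], stats["loss"])  # train_epoch 2.301
```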
@@ -0,0 +1,47 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Learning rate policies.""" | ||
|
||
import numpy as np | ||
from pycls.core.config import cfg | ||
|
||
|
||
def lr_fun_steps(cur_epoch): | ||
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" | ||
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] | ||
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind) | ||
|
||
|
||
def lr_fun_exp(cur_epoch): | ||
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" | ||
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch) | ||
|
||
|
||
def lr_fun_cos(cur_epoch): | ||
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" | ||
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH | ||
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch)) | ||
|
||
|
||
def get_lr_fun(): | ||
"""Retrieves the specified lr policy function""" | ||
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY | ||
if lr_fun not in globals(): | ||
raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY) | ||
return globals()[lr_fun] | ||
|
||
|
||
def get_epoch_lr(cur_epoch): | ||
"""Retrieves the lr for the given epoch according to the policy.""" | ||
lr = get_lr_fun()(cur_epoch) | ||
# Linear warmup | ||
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS: | ||
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS | ||
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha | ||
lr *= warmup_factor | ||
return lr |
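
A standalone worked example of the schedule above, with illustrative values BASE_LR=0.4, MAX_EPOCH=100, WARMUP_EPOCHS=5 and an assumed WARMUP_FACTOR=0.1 (that constant is defined elsewhere in the config): the LR ramps up linearly over the warmup epochs and then follows the cosine curve.

```python
import numpy as np

base_lr, max_epoch = 0.4, 100
warmup_epochs, warmup_factor = 5, 0.1  # warmup_factor is an assumed value

def epoch_lr(cur_epoch):
    lr = 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))  # cosine
    if cur_epoch < warmup_epochs:                                        # linear warmup
        alpha = cur_epoch / warmup_epochs
        lr *= warmup_factor * (1.0 - alpha) + alpha
    return lr

for e in [0, 1, 5, 50, 99]:
    print(e, round(epoch_lr(e), 4))
# ramps from 0.04 during warmup, peaks near 0.4, halves at epoch 50, decays to ~0
```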
@@ -0,0 +1,239 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Meters.""" | ||
|
||
import datetime | ||
from collections import deque | ||
|
||
import numpy as np | ||
import pycls.utils.logging as lu | ||
import pycls.utils.metrics as metrics | ||
from pycls.core.config import cfg | ||
from pycls.utils.timer import Timer | ||
|
||
|
||
def eta_str(eta_td): | ||
"""Converts an eta timedelta to a fixed-width string format.""" | ||
days = eta_td.days | ||
hrs, rem = divmod(eta_td.seconds, 3600) | ||
mins, secs = divmod(rem, 60) | ||
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs) | ||
|
||
|
||
class ScalarMeter(object): | ||
"""Measures a scalar value (adapted from Detectron).""" | ||
|
||
def __init__(self, window_size): | ||
self.deque = deque(maxlen=window_size) | ||
self.total = 0.0 | ||
self.count = 0 | ||
|
||
def reset(self): | ||
self.deque.clear() | ||
self.total = 0.0 | ||
self.count = 0 | ||
|
||
def add_value(self, value): | ||
self.deque.append(value) | ||
self.count += 1 | ||
self.total += value | ||
|
||
def get_win_median(self): | ||
return np.median(self.deque) | ||
|
||
def get_win_avg(self): | ||
return np.mean(self.deque) | ||
|
||
def get_global_avg(self): | ||
return self.total / self.count | ||
|
||
|
||
class TrainMeter(object): | ||
"""Measures training stats.""" | ||
|
||
def __init__(self, epoch_iters): | ||
self.epoch_iters = epoch_iters | ||
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters | ||
self.iter_timer = Timer() | ||
self.loss = ScalarMeter(cfg.LOG_PERIOD) | ||
self.loss_total = 0.0 | ||
self.lr = None | ||
# Current minibatch errors (smoothed over a window) | ||
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) | ||
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) | ||
# Number of misclassified examples | ||
self.num_top1_mis = 0 | ||
self.num_top5_mis = 0 | ||
self.num_samples = 0 | ||
|
||
def reset(self, timer=False): | ||
if timer: | ||
self.iter_timer.reset() | ||
self.loss.reset() | ||
self.loss_total = 0.0 | ||
self.lr = None | ||
self.mb_top1_err.reset() | ||
self.mb_top5_err.reset() | ||
self.num_top1_mis = 0 | ||
self.num_top5_mis = 0 | ||
self.num_samples = 0 | ||
|
||
def iter_tic(self): | ||
self.iter_timer.tic() | ||
|
||
def iter_toc(self): | ||
self.iter_timer.toc() | ||
|
||
def update_stats(self, top1_err, top5_err, loss, lr, mb_size): | ||
# Current minibatch stats | ||
self.mb_top1_err.add_value(top1_err) | ||
self.mb_top5_err.add_value(top5_err) | ||
self.loss.add_value(loss) | ||
self.lr = lr | ||
# Aggregate stats | ||
self.num_top1_mis += top1_err * mb_size | ||
self.num_top5_mis += top5_err * mb_size | ||
self.loss_total += loss * mb_size | ||
self.num_samples += mb_size | ||
|
||
def get_iter_stats(self, cur_epoch, cur_iter): | ||
eta_sec = self.iter_timer.average_time * ( | ||
self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1) | ||
) | ||
eta_td = datetime.timedelta(seconds=int(eta_sec)) | ||
mem_usage = metrics.gpu_mem_usage() | ||
stats = { | ||
"_type": "train_iter", | ||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), | ||
"time_avg": self.iter_timer.average_time, | ||
"time_diff": self.iter_timer.diff, | ||
"eta": eta_str(eta_td), | ||
"top1_err": self.mb_top1_err.get_win_median(), | ||
"top5_err": self.mb_top5_err.get_win_median(), | ||
"loss": self.loss.get_win_median(), | ||
"lr": self.lr, | ||
"mem": int(np.ceil(mem_usage)), | ||
} | ||
return stats | ||
|
||
def log_iter_stats(self, cur_epoch, cur_iter): | ||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0: | ||
return | ||
stats = self.get_iter_stats(cur_epoch, cur_iter) | ||
lu.log_json_stats(stats) | ||
|
||
def get_epoch_stats(self, cur_epoch): | ||
eta_sec = self.iter_timer.average_time * ( | ||
self.max_iter - (cur_epoch + 1) * self.epoch_iters | ||
) | ||
eta_td = datetime.timedelta(seconds=int(eta_sec)) | ||
mem_usage = metrics.gpu_mem_usage() | ||
top1_err = self.num_top1_mis / self.num_samples | ||
top5_err = self.num_top5_mis / self.num_samples | ||
avg_loss = self.loss_total / self.num_samples | ||
stats = { | ||
"_type": "train_epoch", | ||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||
"time_avg": self.iter_timer.average_time, | ||
"eta": eta_str(eta_td), | ||
"top1_err": top1_err, | ||
"top5_err": top5_err, | ||
"loss": avg_loss, | ||
"lr": self.lr, | ||
"mem": int(np.ceil(mem_usage)), | ||
} | ||
return stats | ||
|
||
def log_epoch_stats(self, cur_epoch): | ||
stats = self.get_epoch_stats(cur_epoch) | ||
lu.log_json_stats(stats) | ||
|
||
|
||
class TestMeter(object): | ||
"""Measures testing stats.""" | ||
|
||
def __init__(self, max_iter): | ||
self.max_iter = max_iter | ||
self.iter_timer = Timer() | ||
# Current minibatch errors (smoothed over a window) | ||
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) | ||
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) | ||
# Min errors (over the full test set) | ||
self.min_top1_err = 100.0 | ||
self.min_top5_err = 100.0 | ||
# Number of misclassified examples | ||
self.num_top1_mis = 0 | ||
self.num_top5_mis = 0 | ||
self.num_samples = 0 | ||
|
||
def reset(self, min_errs=False): | ||
if min_errs: | ||
self.min_top1_err = 100.0 | ||
self.min_top5_err = 100.0 | ||
self.iter_timer.reset() | ||
self.mb_top1_err.reset() | ||
self.mb_top5_err.reset() | ||
self.num_top1_mis = 0 | ||
self.num_top5_mis = 0 | ||
self.num_samples = 0 | ||
|
||
def iter_tic(self): | ||
self.iter_timer.tic() | ||
|
||
def iter_toc(self): | ||
self.iter_timer.toc() | ||
|
||
def update_stats(self, top1_err, top5_err, mb_size): | ||
self.mb_top1_err.add_value(top1_err) | ||
self.mb_top5_err.add_value(top5_err) | ||
self.num_top1_mis += top1_err * mb_size | ||
self.num_top5_mis += top5_err * mb_size | ||
self.num_samples += mb_size | ||
|
||
def get_iter_stats(self, cur_epoch, cur_iter): | ||
mem_usage = metrics.gpu_mem_usage() | ||
iter_stats = { | ||
"_type": "test_iter", | ||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||
"iter": "{}/{}".format(cur_iter + 1, self.max_iter), | ||
"time_avg": self.iter_timer.average_time, | ||
"time_diff": self.iter_timer.diff, | ||
"top1_err": self.mb_top1_err.get_win_median(), | ||
"top5_err": self.mb_top5_err.get_win_median(), | ||
"mem": int(np.ceil(mem_usage)), | ||
} | ||
return iter_stats | ||
|
||
def log_iter_stats(self, cur_epoch, cur_iter): | ||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0: | ||
return | ||
stats = self.get_iter_stats(cur_epoch, cur_iter) | ||
lu.log_json_stats(stats) | ||
|
||
def get_epoch_stats(self, cur_epoch): | ||
top1_err = self.num_top1_mis / self.num_samples | ||
top5_err = self.num_top5_mis / self.num_samples | ||
self.min_top1_err = min(self.min_top1_err, top1_err) | ||
self.min_top5_err = min(self.min_top5_err, top5_err) | ||
mem_usage = metrics.gpu_mem_usage() | ||
stats = { | ||
"_type": "test_epoch", | ||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||
"time_avg": self.iter_timer.average_time, | ||
"top1_err": top1_err, | ||
"top5_err": top5_err, | ||
"min_top1_err": self.min_top1_err, | ||
"min_top5_err": self.min_top5_err, | ||
"mem": int(np.ceil(mem_usage)), | ||
} | ||
return stats | ||
|
||
def log_epoch_stats(self, cur_epoch): | ||
stats = self.get_epoch_stats(cur_epoch) | ||
lu.log_json_stats(stats) |
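
ScalarMeter keeps two views of the same stream: a fixed-size window (for the smoothed per-iteration medians) and global totals (for epoch averages). A quick illustration of that bookkeeping with a window of 3:

```python
from collections import deque
import numpy as np

window = deque(maxlen=3)     # the smoothing window
total, count = 0.0, 0        # the global accumulators
for v in [5.0, 4.0, 3.0, 2.0, 1.0]:
    window.append(v)
    total += v
    count += 1
print(np.median(window))     # 2.0 -- median of the last 3 values only
print(total / count)         # 3.0 -- global average over all 5 values
```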
@@ -0,0 +1,104 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Functions for computing metrics.""" | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from pycls.core.config import cfg | ||
|
||
|
||
# Number of bytes in a megabyte | ||
_B_IN_MB = 1024 * 1024 | ||
|
||
|
||
def topks_correct(preds, labels, ks): | ||
"""Computes the number of top-k correct predictions for each k.""" | ||
assert preds.size(0) == labels.size( | ||
0 | ||
), "Batch dim of predictions and labels must match" | ||
# Find the top max_k predictions for each sample | ||
_top_max_k_vals, top_max_k_inds = torch.topk( | ||
preds, max(ks), dim=1, largest=True, sorted=True | ||
) | ||
# (batch_size, max_k) -> (max_k, batch_size) | ||
top_max_k_inds = top_max_k_inds.t() | ||
# (batch_size, ) -> (max_k, batch_size) | ||
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) | ||
# (i, j) = 1 if top i-th prediction for the j-th sample is correct | ||
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) | ||
# Compute the number of topk correct predictions for each k | ||
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks] | ||
return topks_correct | ||
|
||
|
||
def topk_errors(preds, labels, ks): | ||
"""Computes the top-k error for each k.""" | ||
num_topks_correct = topks_correct(preds, labels, ks) | ||
return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] | ||
|
||
|
||
def topk_accuracies(preds, labels, ks): | ||
"""Computes the top-k accuracy for each k.""" | ||
num_topks_correct = topks_correct(preds, labels, ks) | ||
return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] | ||
|
||
|
||
def params_count(model): | ||
"""Computes the number of parameters.""" | ||
return np.sum([p.numel() for p in model.parameters()]).item() | ||
|
||
|
||
def flops_count(model): | ||
"""Computes the number of flops statically.""" | ||
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE | ||
count = 0 | ||
for n, m in model.named_modules(): | ||
if isinstance(m, nn.Conv2d): | ||
if "se." in n: | ||
count += m.in_channels * m.out_channels + m.bias.numel() | ||
continue | ||
h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 | ||
w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 | ||
count += np.prod([m.weight.numel(), h_out, w_out]) | ||
if ".proj" not in n: | ||
h, w = h_out, w_out | ||
elif isinstance(m, nn.MaxPool2d): | ||
h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 | ||
w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 | ||
elif isinstance(m, nn.Linear): | ||
count += m.in_features * m.out_features + m.bias.numel() | ||
return count.item() | ||
|
||
|
||
def acts_count(model): | ||
"""Computes the number of activations statically.""" | ||
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE | ||
count = 0 | ||
for n, m in model.named_modules(): | ||
if isinstance(m, nn.Conv2d): | ||
if "se." in n: | ||
count += m.out_channels | ||
continue | ||
h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 | ||
w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 | ||
count += np.prod([m.out_channels, h_out, w_out]) | ||
if ".proj" not in n: | ||
h, w = h_out, w_out | ||
elif isinstance(m, nn.MaxPool2d): | ||
h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 | ||
w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 | ||
elif isinstance(m, nn.Linear): | ||
count += m.out_features | ||
return count.item() | ||
|
||
|
||
def gpu_mem_usage(): | ||
"""Computes the GPU memory usage for the current device (MB).""" | ||
mem_usage_bytes = torch.cuda.max_memory_allocated() | ||
return mem_usage_bytes / _B_IN_MB |
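
A worked example of the top-k bookkeeping in topks_correct on a batch of three 4-class predictions:

```python
import torch

preds = torch.tensor([
    [0.1, 0.7, 0.1, 0.1],   # top-1 prediction: class 1
    [0.4, 0.3, 0.2, 0.1],   # top-1 prediction: class 0
    [0.1, 0.2, 0.3, 0.4],   # top-1 prediction: class 3
])
labels = torch.tensor([1, 1, 3])

top2 = torch.topk(preds, 2, dim=1, largest=True, sorted=True)[1].t()
correct = top2.eq(labels.view(1, -1).expand_as(top2))
num_correct = [correct[:k, :].reshape(-1).float().sum().item() for k in (1, 2)]
print(num_correct)                                    # [2.0, 3.0]
print([(1.0 - x / 3) * 100.0 for x in num_correct])   # top-1 err ~33.3%, top-2 err 0%
```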
@@ -0,0 +1,57 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Multiprocessing helpers.""" | ||
|
||
import multiprocessing as mp | ||
import traceback | ||
|
||
import pycls.utils.distributed as du | ||
from pycls.utils.error_handler import ErrorHandler | ||
|
||
|
||
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs): | ||
"""Runs a function from a child process.""" | ||
try: | ||
# Initialize the process group | ||
du.init_process_group(proc_rank, world_size) | ||
# Run the function | ||
fun(*fun_args, **fun_kwargs) | ||
except KeyboardInterrupt: | ||
# Killed by the parent process | ||
pass | ||
except Exception: | ||
# Propagate exception to the parent process | ||
error_queue.put(traceback.format_exc()) | ||
finally: | ||
# Destroy the process group | ||
du.destroy_process_group() | ||
|
||
|
||
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): | ||
"""Runs a function in a multi-proc setting.""" | ||
|
||
if fun_kwargs is None: | ||
fun_kwargs = {} | ||
|
||
# Handle errors from training subprocesses | ||
error_queue = mp.SimpleQueue() | ||
error_handler = ErrorHandler(error_queue) | ||
|
||
# Run each training subprocess | ||
ps = [] | ||
for i in range(num_proc): | ||
p_i = mp.Process( | ||
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs) | ||
) | ||
ps.append(p_i) | ||
p_i.start() | ||
error_handler.add_child(p_i.pid) | ||
|
||
# Wait for each subprocess to finish | ||
for p in ps: | ||
p.join() |
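
A hypothetical launcher sketch showing how multi_proc_run is meant to be driven; the entry-point name and the module import path are assumptions (the path matches upstream pycls):

```python
import pycls.utils.multiprocessing as mpu  # assumed import path
from pycls.core.config import cfg

def train_model():
    """Per-process training entry point (stand-in)."""
    ...

if cfg.NUM_GPUS > 1:
    mpu.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_model)
else:
    train_model()
```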
@@ -0,0 +1,94 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Functions for manipulating networks.""" | ||
|
||
import itertools | ||
import math | ||
|
||
import torch | ||
import torch.nn as nn | ||
from pycls.core.config import cfg | ||
|
||
|
||
def init_weights(m): | ||
"""Performs ResNet-style weight initialization.""" | ||
if isinstance(m, nn.Conv2d): | ||
# Note that there is no bias due to BN | ||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) | ||
elif isinstance(m, nn.BatchNorm2d): | ||
zero_init_gamma = ( | ||
hasattr(m, "final_bn") and m.final_bn and cfg.BN.ZERO_INIT_FINAL_GAMMA | ||
) | ||
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) | ||
m.bias.data.zero_() | ||
elif isinstance(m, nn.Linear): | ||
m.weight.data.normal_(mean=0.0, std=0.01) | ||
m.bias.data.zero_() | ||
|
||
|
||
@torch.no_grad() | ||
def compute_precise_bn_stats(model, loader): | ||
"""Computes precise BN stats on training data.""" | ||
# Compute the number of minibatches to use | ||
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) | ||
# Retrieve the BN layers | ||
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] | ||
# Initialize stats storage | ||
mus = [torch.zeros_like(bn.running_mean) for bn in bns] | ||
sqs = [torch.zeros_like(bn.running_var) for bn in bns] | ||
# Remember momentum values | ||
moms = [bn.momentum for bn in bns] | ||
# Disable momentum | ||
for bn in bns: | ||
bn.momentum = 1.0 | ||
# Accumulate the stats across the data samples | ||
for inputs, _labels in itertools.islice(loader, num_iter): | ||
model(inputs.cuda()) | ||
# Accumulate the stats for each BN layer | ||
for i, bn in enumerate(bns): | ||
m, v = bn.running_mean, bn.running_var | ||
sqs[i] += (v + m * m) / num_iter | ||
mus[i] += m / num_iter | ||
# Set the stats and restore momentum values | ||
for i, bn in enumerate(bns): | ||
bn.running_var = sqs[i] - mus[i] * mus[i] | ||
bn.running_mean = mus[i] | ||
bn.momentum = moms[i] | ||
|
||
|
||
def reset_bn_stats(model): | ||
"""Resets running BN stats.""" | ||
for m in model.modules(): | ||
if isinstance(m, torch.nn.BatchNorm2d): | ||
m.reset_running_stats() | ||
|
||
|
||
def drop_connect(x, drop_ratio): | ||
"""Drop connect (adapted from DARTS).""" | ||
keep_ratio = 1.0 - drop_ratio | ||
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) | ||
mask.bernoulli_(keep_ratio) | ||
x.div_(keep_ratio) | ||
x.mul_(mask) | ||
return x | ||
|
||
|
||
def get_flat_weights(model): | ||
"""Gets all model weights as a single flat vector.""" | ||
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) | ||
|
||
|
||
def set_flat_weights(model, flat_weights): | ||
"""Sets all model weights from a single flat vector.""" | ||
k = 0 | ||
for p in model.parameters(): | ||
n = p.data.numel() | ||
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data)) | ||
k += n | ||
assert k == flat_weights.numel() |
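
compute_precise_bn_stats pools per-batch BN statistics via E[x^2] = var + mean^2 and recovers the pooled variance as E[x^2] - E[x]^2. A numeric check with two equal-weight batches:

```python
import torch

means = [torch.tensor(1.0), torch.tensor(3.0)]   # per-batch running means
vars_ = [torch.tensor(0.5), torch.tensor(0.5)]   # per-batch running vars
num_iter = len(means)

mu = sum(m / num_iter for m in means)                            # E[x]   = 2.0
sq = sum((v + m * m) / num_iter for m, v in zip(means, vars_))   # E[x^2] = 5.5
print(mu.item(), (sq - mu * mu).item())  # 2.0 1.5  (pooled mean and variance)
```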
@@ -0,0 +1,132 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Plotting functions.""" | ||
|
||
import colorlover as cl | ||
import matplotlib.pyplot as plt | ||
import plotly.graph_objs as go | ||
import plotly.offline as offline | ||
import pycls.utils.logging as lu | ||
|
||
|
||
def get_plot_colors(max_colors, color_format="pyplot"): | ||
"""Generate colors for plotting.""" | ||
colors = cl.scales["11"]["qual"]["Paired"] | ||
if max_colors > len(colors): | ||
colors = cl.to_rgb(cl.interp(colors, max_colors)) | ||
if color_format == "pyplot": | ||
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] | ||
return colors | ||
|
||
|
||
def prepare_plot_data(log_files, names, key="top1_err"): | ||
"""Load logs and extract data for plotting error curves.""" | ||
plot_data = [] | ||
for file, name in zip(log_files, names): | ||
d, log = {}, lu.load_json_stats(file) | ||
for phase in ["train", "test"]: | ||
x = lu.parse_json_stats(log, phase + "_epoch", "epoch") | ||
y = lu.parse_json_stats(log, phase + "_epoch", key) | ||
d["x_" + phase], d["y_" + phase] = x, y | ||
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name | ||
plot_data.append(d) | ||
assert len(plot_data) > 0, "No data to plot" | ||
return plot_data | ||
|
||
|
||
def plot_error_curves_plotly(log_files, names, filename, key="top1_err"): | ||
"""Plot error curves using plotly and save to file.""" | ||
plot_data = prepare_plot_data(log_files, names, key) | ||
colors = get_plot_colors(len(plot_data), "plotly") | ||
# Prepare data for plots (3 sets, train duplicated w and w/o legend) | ||
data = [] | ||
for i, d in enumerate(plot_data): | ||
s = str(i) | ||
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} | ||
line_test = {"color": colors[i], "dash": "solid", "width": 1.5} | ||
data.append( | ||
go.Scatter( | ||
x=d["x_train"], | ||
y=d["y_train"], | ||
mode="lines", | ||
name=d["train_label"], | ||
line=line_train, | ||
legendgroup=s, | ||
visible=True, | ||
showlegend=False, | ||
) | ||
) | ||
data.append( | ||
go.Scatter( | ||
x=d["x_test"], | ||
y=d["y_test"], | ||
mode="lines", | ||
name=d["test_label"], | ||
line=line_test, | ||
legendgroup=s, | ||
visible=True, | ||
showlegend=True, | ||
) | ||
) | ||
data.append( | ||
go.Scatter( | ||
x=d["x_train"], | ||
y=d["y_train"], | ||
mode="lines", | ||
name=d["train_label"], | ||
line=line_train, | ||
legendgroup=s, | ||
visible=False, | ||
showlegend=True, | ||
) | ||
) | ||
# Prepare layout w ability to toggle 'all', 'train', 'test' | ||
titlefont = {"size": 18, "color": "#7f7f7f"} | ||
vis = [[True, True, False], [False, False, True], [False, True, False]] | ||
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) | ||
buttons = [{"label": l, "args": v, "method": "update"} for l, v in buttons] | ||
layout = go.Layout( | ||
title=key + " vs. epoch<br>[dash=train, solid=test]", | ||
xaxis={"title": "epoch", "titlefont": titlefont}, | ||
yaxis={"title": key, "titlefont": titlefont}, | ||
showlegend=True, | ||
hoverlabel={"namelength": -1}, | ||
updatemenus=[ | ||
{ | ||
"buttons": buttons, | ||
"direction": "down", | ||
"showactive": True, | ||
"x": 1.02, | ||
"xanchor": "left", | ||
"y": 1.08, | ||
"yanchor": "top", | ||
} | ||
], | ||
) | ||
# Create plotly plot | ||
offline.plot({"data": data, "layout": layout}, filename=filename) | ||
|
||
|
||
def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_err"): | ||
"""Plot error curves using matplotlib.pyplot and save to file.""" | ||
plot_data = prepare_plot_data(log_files, names, key) | ||
colors = get_plot_colors(len(names)) | ||
for ind, d in enumerate(plot_data): | ||
c, lbl = colors[ind], d["test_label"] | ||
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) | ||
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) | ||
plt.title(key + " vs. epoch\n[dash=train, solid=test]", fontsize=14) | ||
plt.xlabel("epoch", fontsize=14) | ||
plt.ylabel(key, fontsize=14) | ||
plt.grid(alpha=0.4) | ||
plt.legend() | ||
if filename: | ||
plt.savefig(filename) | ||
plt.clf() | ||
else: | ||
plt.show() |
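
A hypothetical call for the pyplot variant (the run names and log paths are made up):

```python
plot_error_curves_pyplot(
    log_files=["runs/regnety_8g/stdout.log", "runs/baseline/stdout.log"],
    names=["regnety_8g", "baseline"],
    filename="top1_err.png",
    key="top1_err",
)
```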
@@ -0,0 +1,35 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Timer.""" | ||
|
||
import time | ||
|
||
|
||
class Timer(object): | ||
"""A simple timer (adapted from Detectron).""" | ||
|
||
def __init__(self): | ||
self.reset() | ||
|
||
def tic(self): | ||
# using time.time instead of time.clock because time.clock | ||
# does not normalize for multithreading | ||
self.start_time = time.time() | ||
|
||
def toc(self): | ||
self.diff = time.time() - self.start_time | ||
self.total_time += self.diff | ||
self.calls += 1 | ||
self.average_time = self.total_time / self.calls | ||
|
||
def reset(self): | ||
self.total_time = 0.0 | ||
self.calls = 0 | ||
self.start_time = 0.0 | ||
self.diff = 0.0 | ||
self.average_time = 0.0 |
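
Minimal usage of Timer, mirroring how the benchmark code drives it (the import path is the one already used by the benchmark module above):

```python
import time
from pycls.utils.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)   # stand-in for real work
    timer.toc()
print("{:.3f}s average over {} calls".format(timer.average_time, timer.calls))
```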
@@ -0,0 +1,78 @@ | ||
|
||
import torch | ||
import argparse | ||
import numpy as np | ||
|
||
from PIL import Image | ||
|
||
from RegDanbooru2019_8G import RegDanbooru2019 | ||
|
||
parser = argparse.ArgumentParser(description='Test RegDeepDanbooru') | ||
parser.add_argument('--model', default='', type=str, help='trained model') | ||
parser.add_argument('--image', default='', type=str, help='image to test') | ||
parser.add_argument('--size', default=768, type=int, help='canvas size') | ||
parser.add_argument('--threshold', default=0.5, type=float, help='threshold') | ||
args = parser.parse_args() | ||
|
||
DANBOORU_LABEL_MAP = {} | ||
|
||
def load_danbooru_label_map() : | ||
print(' -- Loading danbooru2019 labels') | ||
global DANBOORU_LABEL_MAP | ||
with open('danbooru_labels.txt', 'r') as fp : | ||
for l in fp : | ||
l = l.strip() | ||
if l : | ||
idx, tag = l.split(' ') | ||
DANBOORU_LABEL_MAP[int(idx)] = tag | ||
|
||
def test(model, image_resized) : | ||
print(' -- Running model on GPU') | ||
image_resized_torch = torch.from_numpy(image_resized).float() / 127.5 - 1.0 | ||
if len(image_resized_torch.shape) == 3 : | ||
image_resized_torch = image_resized_torch.unsqueeze(0).permute(0, 3, 1, 2) | ||
elif len(image_resized_torch.shape) == 4 : | ||
image_resized_torch = image_resized_torch.permute(0, 3, 1, 2) | ||
image_resized_torch = image_resized_torch.cuda() | ||
with torch.no_grad() : | ||
danbooru_logits = model(image_resized_torch) | ||
danbooru = danbooru_logits.sigmoid().cpu() | ||
return danbooru | ||
|
||
def load_and_resize_image(img_path, canvas_size = 512) : | ||
img = Image.open(img_path).convert('RGB') | ||
old_size = img.size | ||
ratio = float(canvas_size) / max(old_size) | ||
new_size = tuple([int(round(x * ratio)) for x in old_size]) | ||
print(f'Test image size: {new_size}') | ||
return np.array(img.resize(new_size, Image.LANCZOS))  # LANCZOS replaces the deprecated ANTIALIAS alias | ||
|
||
def translate_danbooru_labels(probs, threshold = 0.8) : | ||
global DANBOORU_LABEL_MAP | ||
chosen_indices = (probs > threshold).nonzero() | ||
result = [] | ||
for i in range(probs.size(0)) : | ||
prob_single = probs[i].numpy()  # probs for sample i, not always sample 0 | ||
indices_single = chosen_indices[chosen_indices[:, 0] == i][:, 1].numpy() | ||
tag_prob_map = {DANBOORU_LABEL_MAP[idx]: prob_single[idx] for idx in indices_single} | ||
result.append(tag_prob_map) | ||
return result | ||
|
||
def main() : | ||
model = RegDanbooru2019().cuda() | ||
model.load_state_dict(torch.load(args.model)['model']) | ||
model.eval() | ||
torch.save(model, 'RegNetY-8G.pth')  # also export the full model object | ||
|
||
test_img = load_and_resize_image(args.image, args.size) | ||
|
||
danbooru = test(model, test_img) | ||
|
||
tags = translate_danbooru_labels(danbooru, args.threshold) | ||
print(tags) | ||
|
||
if __name__ == "__main__": | ||
load_danbooru_label_map() | ||
main() |
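
For reference, a toy run of the thresholding logic in translate_danbooru_labels on a fake 1x4096 probability tensor with a made-up two-entry label map:

```python
import torch

DANBOORU_LABEL_MAP = {0: "1girl", 1: "solo"}   # hypothetical entries
probs = torch.zeros(1, 4096)
probs[0, 0], probs[0, 1] = 0.97, 0.62

chosen = (probs > 0.5).nonzero()
for i in range(probs.size(0)):
    idx = chosen[chosen[:, 0] == i][:, 1].numpy()
    print({DANBOORU_LABEL_MAP[int(j)]: round(float(probs[i, j]), 4) for j in idx})
# {'1girl': 0.97, 'solo': 0.62}
```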