
Commit

first commit
zyddnys committed Oct 11, 2020
0 parents commit bf20f72
Showing 65 changed files with 7,467 additions and 0 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -0,0 +1,11 @@
# Yet another Deep Danbooru project
But based on [RegNetY-8G](https://arxiv.org/abs/2003.13678), relatively lightweight and designed to run fast on GPU. \
Training was done using mixed-precision training on a single RTX 2080 Ti for 3 weeks. \
Some code is from https://github.com/facebookresearch/pycls
# What do I need?
You need to download [save_4000000.ckpt]() from the release and place it in the same folder as `test.py`.
# How to use?
`python test.py --model save_4000000.ckpt --image <PATH_TO_IMAGE>`
# What to do in the future?
1. Quantize to 8 bit (a possible approach is sketched below)
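
For item 1, a hedged sketch of what 8-bit quantization could look like with PyTorch's post-training dynamic quantization. This is not the project's actual plan: the checkpoint loading shown is an assumption, and dynamic quantization only covers the `nn.Linear` head (quantizing the RegNet convolutions would need static quantization with calibration).

```python
import torch
from RegDanbooru2019_8G import RegDanbooru2019

# Hypothetical loading; the real layout of save_4000000.ckpt may differ.
model = RegDanbooru2019()
state = torch.load('save_4000000.ckpt', map_location='cpu')
model.load_state_dict(state, strict=False)
model.eval()

# Convert nn.Linear weights to int8; activations are quantized on the fly
# at inference time (CPU execution only).
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
torch.save(quantized.state_dict(), 'save_4000000_int8.ckpt')
```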

33 changes: 33 additions & 0 deletions RegDanbooru2019_8G.py
@@ -0,0 +1,33 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from RegNetY_8G import build_model

class RegDanbooru2019(nn.Module) :
    def __init__(self) :
        super(RegDanbooru2019, self).__init__()
        self.backbone = build_model()
        num_p = sum(p.numel() for p in self.backbone.parameters() if p.requires_grad)
        print( 'Backbone has %d parameters' % num_p )
        self.head_danbooru = nn.Linear(2016, 4096)

    def forward_train_head(self, images) :
        """
        images of shape [N, 3, 512, 512]
        """
        with torch.no_grad() :
            feats = self.backbone(images)
            feats = F.adaptive_avg_pool2d(feats, 1).view(-1, 2016)
        danbooru_logits = self.head_danbooru(feats) # [N, 4096]
        return danbooru_logits

    def forward(self, images) :
        """
        images of shape [N, 3, 512, 512]
        """
        feats = self.backbone(images)
        feats = F.adaptive_avg_pool2d(feats, 1).view(-1, 2016)
        danbooru_logits = self.head_danbooru(feats) # [N, 4096]
        return danbooru_logits
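
A minimal usage sketch for the module above. The 512x512 input follows the shape comments in the code; treating the 4,096 logits as independent per-tag scores (sigmoid) matches the 4,096-line `danbooru_labels.txt` added in this commit, and the 0.5 threshold is purely illustrative.

```python
import torch
from RegDanbooru2019_8G import RegDanbooru2019

# The pycls builder already places the backbone on the GPU; move the head too.
model = RegDanbooru2019().cuda().eval()

# Dummy batch of one 512x512 RGB image
images = torch.randn(1, 3, 512, 512).cuda()

with torch.no_grad():
    logits = model(images)         # [1, 4096]
    probs = torch.sigmoid(logits)  # per-tag probabilities (multi-label)

# Indices into danbooru_labels.txt whose probability exceeds the threshold
tags = (probs[0] > 0.5).nonzero(as_tuple=True)[0].tolist()
```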
27 changes: 27 additions & 0 deletions RegNetY-8.0GF_dds_8gpu.yaml
@@ -0,0 +1,27 @@
MODEL:
  TYPE: regnet
  NUM_CLASSES: 1000
REGNET:
  SE_ON: true
  DEPTH: 17
  W0: 192
  WA: 76.82
  WM: 2.19
  GROUP_W: 56
OPTIM:
  LR_POLICY: cos
  BASE_LR: 0.4
  MAX_EPOCH: 100
  MOMENTUM: 0.9
  WEIGHT_DECAY: 5e-5
  WARMUP_EPOCHS: 5
TRAIN:
  DATASET: imagenet
  IM_SIZE: 512
  BATCH_SIZE: 512
TEST:
  DATASET: imagenet
  IM_SIZE: 512
  BATCH_SIZE: 400
NUM_GPUS: 1
OUT_DIR: .
67 changes: 67 additions & 0 deletions RegNetY_8G.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Test a trained classification model."""

import argparse
import sys

import numpy as np
import pycls.core.losses as losses
import pycls.core.model_builder as model_builder
import pycls.datasets.loader as loader
import pycls.utils.benchmark as bu
import pycls.utils.checkpoint as cu
import pycls.utils.distributed as du
import pycls.utils.logging as lu
import pycls.utils.metrics as mu
import pycls.utils.multiprocessing as mpu
import pycls.utils.net as nu
import torch
from pycls.core.config import assert_and_infer_cfg, cfg
from pycls.utils.meters import TestMeter


logger = lu.get_logger(__name__)

def log_model_info(model):
    """Logs model info"""
    logger.info("Model:\n{}".format(model))
    logger.info("Params: {:,}".format(mu.params_count(model)))
    logger.info("Flops: {:,}".format(mu.flops_count(model)))
    logger.info("Acts: {:,}".format(mu.acts_count(model)))

def build_model():

    # Load config options
    cfg.merge_from_file('RegNetY-8.0GF_dds_8gpu.yaml')
    cfg.merge_from_list([])
    assert_and_infer_cfg()
    cfg.freeze()
    # Setup logging
    lu.setup_logging()
    # Show the config
    logger.info("Config:\n{}".format(cfg))

    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK

    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)

    # Load model weights
    #cu.load_checkpoint('RegNetY-8.0GF_dds_8gpu.pyth', model)
    logger.info("Loaded model weights from: {}".format('RegNetY-8.0GF_dds_8gpu.pyth'))

    del model.head

    return model
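
Since `del model.head` removes the classifier, `build_model()` returns the backbone as a feature extractor ending at the last stage. A small sketch of the assumed output shape: 2016 channels matches `nn.Linear(2016, 4096)` in `RegDanbooru2019_8G.py`, and 16x16 assumes the usual 32x total stride (stem plus four strided stages) on a 512x512 input.

```python
import torch
from RegNetY_8G import build_model

# pycls's model builder places the network on the current GPU device.
backbone = build_model().eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512).cuda())
print(feats.shape)  # expected: torch.Size([1, 2016, 16, 16])
```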

4,096 changes: 4,096 additions & 0 deletions danbooru_labels.txt


Empty file added pycls/__init__.py
Empty file.
Binary file added pycls/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Empty file added pycls/core/__init__.py
Empty file.
Binary file added pycls/core/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added pycls/core/__pycache__/config.cpython-37.pyc
Binary file not shown.
Binary file added pycls/core/__pycache__/losses.cpython-37.pyc
Binary file not shown.
Binary file not shown.
410 changes: 410 additions & 0 deletions pycls/core/config.py


28 changes: 28 additions & 0 deletions pycls/core/losses.py
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Loss functions."""

import torch.nn as nn
from pycls.core.config import cfg


# Supported loss functions
_loss_funs = {"cross_entropy": nn.CrossEntropyLoss}


def get_loss_fun():
"""Retrieves the loss function."""
assert (
cfg.MODEL.LOSS_FUN in _loss_funs.keys()
), "Loss function '{}' not supported".format(cfg.TRAIN.LOSS)
return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()


def register_loss_fun(name, ctor):
"""Registers a loss function dynamically."""
_loss_funs[name] = ctor
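
`register_loss_fun` is the extension point for losses that pycls does not ship. A multi-label tagger like this one would typically train its 4,096-way head with a binary cross-entropy over logits; a hedged sketch of wiring that in (the name `bce_logits` and the loss choice are illustrations, not what this commit actually uses):

```python
import torch.nn as nn
import pycls.core.losses as losses

# Register a multi-label loss under a new name; selecting it would then be
# a matter of setting cfg.MODEL.LOSS_FUN = "bce_logits" in the config.
losses.register_loss_fun("bce_logits", nn.BCEWithLogitsLoss)
```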
50 changes: 50 additions & 0 deletions pycls/core/model_builder.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Model construction functions."""

import pycls.utils.logging as lu
import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet


logger = lu.get_logger(__name__)

# Supported models
_models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet}


def build_model():
"""Builds the model."""
assert cfg.MODEL.TYPE in _models.keys(), "Model type '{}' not supported".format(
cfg.MODEL.TYPE
)
assert (
cfg.NUM_GPUS <= torch.cuda.device_count()
), "Cannot use more GPU devices than available"
# Construct the model
model = _models[cfg.MODEL.TYPE]()
# Determine the GPU used by the current process
cur_device = torch.cuda.current_device()
# Transfer the model to the current GPU device
model = model.cuda(device=cur_device)
# Use multi-process data parallel model in the multi-gpu setting
if cfg.NUM_GPUS > 1:
# Make model replica operate on the current device
model = torch.nn.parallel.DistributedDataParallel(
module=model, device_ids=[cur_device], output_device=cur_device
)
return model


def register_model(name, ctor):
"""Registers a model dynamically."""
_models[name] = ctor
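
`register_model` plays the same role for architectures: anything registered here becomes reachable through `cfg.MODEL.TYPE`. A purely illustrative sketch registering the Danbooru wrapper from this commit (the commit itself bypasses this and calls `build_model()` from `RegNetY_8G.py` directly):

```python
import pycls.core.model_builder as model_builder
from RegDanbooru2019_8G import RegDanbooru2019

# After this, cfg.MODEL.TYPE = "regdanbooru" would construct the tagger
# through pycls's build_model(), including the GPU transfer above.
model_builder.register_model("regdanbooru", RegDanbooru2019)
```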
79 changes: 79 additions & 0 deletions pycls/core/optimizer.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Optimizer."""

import pycls.utils.lr_policy as lr_policy
import torch
from pycls.core.config import cfg


def construct_optimizer(model):
"""Constructs the optimizer.
Note that the momentum update in PyTorch differs from the one in Caffe2.
In particular,
Caffe2:
V := mu * V + lr * g
p := p - V
PyTorch:
V := mu * V + g
p := p - lr * V
where V is the velocity, mu is the momentum factor, lr is the learning rate,
g is the gradient and p are the parameters.
Since V is defined independently of the learning rate in PyTorch,
when the learning rate is changed there is no need to perform the
momentum correction by scaling V (unlike in the Caffe2 case).
"""
# Batchnorm parameters.
bn_params = []
# Non-batchnorm parameters.
non_bn_parameters = []
for name, p in model.named_parameters():
if "bn" in name:
bn_params.append(p)
else:
non_bn_parameters.append(p)
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
bn_weight_decay = (
cfg.BN.CUSTOM_WEIGHT_DECAY
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY
else cfg.OPTIM.WEIGHT_DECAY
)
optim_params = [
{"params": bn_params, "weight_decay": bn_weight_decay},
{"params": non_bn_parameters, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
]
# Check all parameters will be passed into optimizer.
assert len(list(model.parameters())) == len(non_bn_parameters) + len(
bn_params
), "parameter size does not match: {} + {} != {}".format(
len(non_bn_parameters), len(bn_params), len(list(model.parameters()))
)
return torch.optim.SGD(
optim_params,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV,
)


def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch (as specified by the lr policy)."""
return lr_policy.get_epoch_lr(cur_epoch)


def set_lr(optimizer, new_lr):
"""Sets the optimizer lr to the specified value."""
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
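
The Caffe2 vs. PyTorch note in the docstring can be checked numerically: with a constant learning rate the two recursions maintain V_caffe2 = lr * V_pytorch, so the parameter trajectories coincide. A small self-contained sketch:

```python
mu, lr = 0.9, 0.1
grads = [0.5, -0.2, 0.3, 0.1]

p_c2, v_c2 = 1.0, 0.0  # Caffe2-style:  V = mu*V + lr*g;  p = p - V
p_pt, v_pt = 1.0, 0.0  # PyTorch-style: V = mu*V + g;     p = p - lr*V

for g in grads:
    v_c2 = mu * v_c2 + lr * g
    p_c2 -= v_c2
    v_pt = mu * v_pt + g
    p_pt -= lr * v_pt

assert abs(p_c2 - p_pt) < 1e-12  # identical updates while lr stays constant
```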
Empty file added pycls/datasets/__init__.py
Empty file.
Binary file added pycls/datasets/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added pycls/datasets/__pycache__/cifar10.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file added pycls/datasets/__pycache__/loader.cpython-37.pyc
Binary file not shown.
Binary file added pycls/datasets/__pycache__/paths.cpython-37.pyc
Binary file not shown.
Binary file not shown.
83 changes: 83 additions & 0 deletions pycls/datasets/cifar10.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""CIFAR10 dataset."""

import os
import pickle

import numpy as np
import pycls.datasets.transforms as transforms
import pycls.utils.logging as lu
import torch
import torch.utils.data
from pycls.core.config import cfg


logger = lu.get_logger(__name__)

# Per-channel mean and SD values in BGR order
_MEAN = [125.3, 123.0, 113.9]
_SD = [63.0, 62.1, 66.7]


class Cifar10(torch.utils.data.Dataset):
"""CIFAR-10 dataset."""

def __init__(self, data_path, split):
assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
assert split in ["train", "test"], "Split '{}' not supported for cifar".format(
split
)
logger.info("Constructing CIFAR-10 {}...".format(split))
self._data_path = data_path
self._split = split
# Data format:
# self._inputs - (split_size, 3, im_size, im_size) ndarray
# self._labels - split_size list
self._inputs, self._labels = self._load_data()

def _load_batch(self, batch_path):
with open(batch_path, "rb") as f:
d = pickle.load(f, encoding="bytes")
return d[b"data"], d[b"labels"]

def _load_data(self):
"""Loads data in memory."""
logger.info("{} data path: {}".format(self._split, self._data_path))
# Compute data batch names
if self._split == "train":
batch_names = ["data_batch_{}".format(i) for i in range(1, 6)]
else:
batch_names = ["test_batch"]
# Load data batches
inputs, labels = [], []
for batch_name in batch_names:
batch_path = os.path.join(self._data_path, batch_name)
inputs_batch, labels_batch = self._load_batch(batch_path)
inputs.append(inputs_batch)
labels += labels_batch
# Combine and reshape the inputs
inputs = np.vstack(inputs).astype(np.float32)
inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
return inputs, labels

def _prepare_im(self, im):
"""Prepares the image for network input."""
im = transforms.color_norm(im, _MEAN, _SD)
if self._split == "train":
im = transforms.horizontal_flip(im=im, p=0.5)
im = transforms.random_crop(im=im, size=cfg.TRAIN.IM_SIZE, pad_size=4)
return im

def __getitem__(self, index):
im, label = self._inputs[index, ...].copy(), self._labels[index]
im = self._prepare_im(im)
return im, label

def __len__(self):
return self._inputs.shape[0]
108 changes: 108 additions & 0 deletions pycls/datasets/imagenet.py
@@ -0,0 +1,108 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""ImageNet dataset."""

import os
import re

import cv2
import numpy as np
import pycls.datasets.transforms as transforms
import pycls.utils.logging as lu
import torch
import torch.utils.data
from pycls.core.config import cfg


logger = lu.get_logger(__name__)

# Per-channel mean and SD values in BGR order
_MEAN = [0.406, 0.456, 0.485]
_SD = [0.225, 0.224, 0.229]

# Eig vals and vecs of the cov mat
_EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]])
_EIG_VECS = np.array(
[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
)


class ImageNet(torch.utils.data.Dataset):
"""ImageNet dataset."""

def __init__(self, data_path, split):
assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
assert split in [
"train",
"val",
], "Split '{}' not supported for ImageNet".format(split)
logger.info("Constructing ImageNet {}...".format(split))
self._data_path = data_path
self._split = split
self._construct_imdb()

def _construct_imdb(self):
"""Constructs the imdb."""
# Compile the split data path
split_path = os.path.join(self._data_path, self._split)
logger.info("{} data path: {}".format(self._split, split_path))
# Images are stored per class in subdirs (format: n<number>)
self._class_ids = sorted(
f for f in os.listdir(split_path) if re.match(r"^n[0-9]+$", f)
)
# Map ImageNet class ids to contiguous ids
self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
# Construct the image db
self._imdb = []
for class_id in self._class_ids:
cont_id = self._class_id_cont_id[class_id]
im_dir = os.path.join(split_path, class_id)
for im_name in os.listdir(im_dir):
self._imdb.append(
{"im_path": os.path.join(im_dir, im_name), "class": cont_id}
)
logger.info("Number of images: {}".format(len(self._imdb)))
logger.info("Number of classes: {}".format(len(self._class_ids)))

def _prepare_im(self, im):
"""Prepares the image for network input."""
# Train and test setups differ
if self._split == "train":
# Scale and aspect ratio
im = transforms.random_sized_crop(
im=im, size=cfg.TRAIN.IM_SIZE, area_frac=0.08
)
# Horizontal flip
im = transforms.horizontal_flip(im=im, p=0.5, order="HWC")
else:
# Scale and center crop
im = transforms.scale(cfg.TEST.IM_SIZE, im)
im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
# HWC -> CHW
im = im.transpose([2, 0, 1])
# [0, 255] -> [0, 1]
im = im / 255.0
# PCA jitter
if self._split == "train":
im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS)
# Color normalization
im = transforms.color_norm(im, _MEAN, _SD)
return im

def __getitem__(self, index):
# Load the image
im = cv2.imread(self._imdb[index]["im_path"])
im = im.astype(np.float32, copy=False)
# Prepare the image for training / testing
im = self._prepare_im(im)
# Retrieve the label
label = self._imdb[index]["class"]
return im, label

def __len__(self):
return len(self._imdb)
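
The test-time branch of `_prepare_im` above can be reproduced standalone with the helpers from `pycls/datasets/transforms.py`; a sketch assuming a BGR `float32` image from `cv2.imread` (as in `__getitem__`) and the 512-pixel sizes from this repo's YAML config ("example.jpg" is a placeholder path):

```python
import cv2
import numpy as np
import pycls.datasets.transforms as transforms

# Per-channel BGR mean/SD, matching the constants at the top of this file
_MEAN = [0.406, 0.456, 0.485]
_SD = [0.225, 0.224, 0.229]

im = cv2.imread("example.jpg").astype(np.float32)  # HWC, BGR, [0, 255]
im = transforms.scale(512, im)                     # shorter side -> 512
im = transforms.center_crop(512, im)               # 512x512 center crop
im = im.transpose([2, 0, 1]) / 255.0               # HWC -> CHW, [0, 1]
im = transforms.color_norm(im, _MEAN, _SD)         # per-channel normalize
```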
80 changes: 80 additions & 0 deletions pycls/datasets/loader.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Data loader."""

import pycls.datasets.paths as dp
import torch
from pycls.core.config import cfg
from pycls.datasets.cifar10 import Cifar10
from pycls.datasets.imagenet import ImageNet
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler


# Supported datasets
_DATASET_CATALOG = {"cifar10": Cifar10, "imagenet": ImageNet}


def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
"""Constructs the data loader for the given dataset."""
assert dataset_name in _DATASET_CATALOG.keys(), "Dataset '{}' not supported".format(
dataset_name
)
assert dp.has_data_path(dataset_name), "Dataset '{}' has no data path".format(
dataset_name
)
# Retrieve the data path for the dataset
data_path = dp.get_data_path(dataset_name)
# Construct the dataset
dataset = _DATASET_CATALOG[dataset_name](data_path, split)
# Create a sampler for multi-process training
sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
# Create a loader
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=(False if sampler else shuffle),
sampler=sampler,
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
drop_last=drop_last,
)
return loader


def construct_train_loader():
"""Train loader wrapper."""
return _construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
)


def construct_test_loader():
"""Test loader wrapper."""
return _construct_loader(
dataset_name=cfg.TEST.DATASET,
split=cfg.TEST.SPLIT,
batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=False,
drop_last=False,
)


def shuffle(loader, cur_epoch):
""""Shuffles the data."""
assert isinstance(
loader.sampler, (RandomSampler, DistributedSampler)
), "Sampler type '{}' not supported".format(type(loader.sampler))
# RandomSampler handles shuffling automatically
if isinstance(loader.sampler, DistributedSampler):
# DistributedSampler shuffles data based on epoch
loader.sampler.set_epoch(cur_epoch)
35 changes: 35 additions & 0 deletions pycls/datasets/paths.py
@@ -0,0 +1,35 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Dataset paths."""

import os


# Default data directory (/path/pycls/pycls/datasets/data)
_DEF_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

# Data paths
_paths = {
"cifar10": _DEF_DATA_DIR + "/cifar10",
"imagenet": _DEF_DATA_DIR + "/imagenet",
}


def has_data_path(dataset_name):
"""Determines if the dataset has a data path."""
return dataset_name in _paths.keys()


def get_data_path(dataset_name):
"""Retrieves data path for the dataset."""
return _paths[dataset_name]


def register_path(name, path):
"""Registers a dataset path dynamically."""
_paths[name] = path
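
`register_path` is how a dataset outside the two defaults would be pointed at its data; a sketch for a hypothetical local Danbooru dump (the name and path are illustrative only):

```python
import pycls.datasets.paths as dp

# Hypothetical: point a "danbooru2019" dataset at a local directory.
dp.register_path("danbooru2019", "/data/danbooru2019")
assert dp.has_data_path("danbooru2019")
```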
108 changes: 108 additions & 0 deletions pycls/datasets/transforms.py
@@ -0,0 +1,108 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Image transformations."""

import math

import cv2
import numpy as np


def color_norm(im, mean, std):
"""Performs per-channel normalization (CHW format)."""
for i in range(im.shape[0]):
im[i] = im[i] - mean[i]
im[i] = im[i] / std[i]
return im


def zero_pad(im, pad_size):
"""Performs zero padding (CHW format)."""
pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size))
return np.pad(im, pad_width, mode="constant")


def horizontal_flip(im, p, order="CHW"):
"""Performs horizontal flip (CHW or HWC format)."""
assert order in ["CHW", "HWC"]
if np.random.uniform() < p:
if order == "CHW":
im = im[:, :, ::-1]
else:
im = im[:, ::-1, :]
return im


def random_crop(im, size, pad_size=0):
"""Performs random crop (CHW format)."""
if pad_size > 0:
im = zero_pad(im=im, pad_size=pad_size)
h, w = im.shape[1:]
y = np.random.randint(0, h - size)
x = np.random.randint(0, w - size)
im_crop = im[:, y : (y + size), x : (x + size)]
assert im_crop.shape[1:] == (size, size)
return im_crop


def scale(size, im):
"""Performs scaling (HWC format)."""
h, w = im.shape[:2]
if (w <= h and w == size) or (h <= w and h == size):
return im
h_new, w_new = size, size
if w < h:
h_new = int(math.floor((float(h) / w) * size))
else:
w_new = int(math.floor((float(w) / h) * size))
im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
return im.astype(np.float32)


def center_crop(size, im):
"""Performs center cropping (HWC format)."""
h, w = im.shape[:2]
y = int(math.ceil((h - size) / 2))
x = int(math.ceil((w - size) / 2))
im_crop = im[y : (y + size), x : (x + size), :]
assert im_crop.shape[:2] == (size, size)
return im_crop


def random_sized_crop(im, size, area_frac=0.08, max_iter=10):
"""Performs Inception-style cropping (HWC format)."""
h, w = im.shape[:2]
area = h * w
for _ in range(max_iter):
target_area = np.random.uniform(area_frac, 1.0) * area
aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio)))
h_crop = int(round(math.sqrt(float(target_area) / aspect_ratio)))
if np.random.uniform() < 0.5:
w_crop, h_crop = h_crop, w_crop
if h_crop <= h and w_crop <= w:
y = 0 if h_crop == h else np.random.randint(0, h - h_crop)
x = 0 if w_crop == w else np.random.randint(0, w - w_crop)
im_crop = im[y : (y + h_crop), x : (x + w_crop), :]
assert im_crop.shape[:2] == (h_crop, w_crop)
im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR)
return im_crop.astype(np.float32)
return center_crop(size, scale(size, im))


def lighting(im, alpha_std, eig_val, eig_vec):
"""Performs AlexNet-style PCA jitter (CHW format)."""
if alpha_std == 0:
return im
alpha = np.random.normal(0, alpha_std, size=(1, 3))
rgb = np.sum(
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1
)
for i in range(im.shape[0]):
im[i] = im[i] + rgb[2 - i]
return im
Empty file added pycls/models/__init__.py
Empty file.
Binary file added pycls/models/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added pycls/models/__pycache__/anynet.cpython-37.pyc
Binary file not shown.
Binary file added pycls/models/__pycache__/effnet.cpython-37.pyc
Binary file not shown.
Binary file added pycls/models/__pycache__/regnet.cpython-37.pyc
Binary file not shown.
Binary file added pycls/models/__pycache__/resnet.cpython-37.pyc
Binary file not shown.
380 changes: 380 additions & 0 deletions pycls/models/anynet.py
@@ -0,0 +1,380 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""AnyNet models."""

import pycls.utils.logging as lu
import pycls.utils.net as nu
import torch.nn as nn
from pycls.core.config import cfg


logger = lu.get_logger(__name__)


def get_stem_fun(stem_type):
"""Retrives the stem function by name."""
stem_funs = {
"res_stem_cifar": ResStemCifar,
"res_stem_in": ResStemIN,
"simple_stem_in": SimpleStemIN,
}
assert stem_type in stem_funs.keys(), "Stem type '{}' not supported".format(
stem_type
)
return stem_funs[stem_type]


def get_block_fun(block_type):
"""Retrieves the block function by name."""
block_funs = {
"vanilla_block": VanillaBlock,
"res_basic_block": ResBasicBlock,
"res_bottleneck_block": ResBottleneckBlock,
}
assert block_type in block_funs.keys(), "Block type '{}' not supported".format(
block_type
)
return block_funs[block_type]


class AnyHead(nn.Module):
"""AnyNet head."""

def __init__(self, w_in, nc):
super(AnyHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)

def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x


class VanillaBlock(nn.Module):
"""Vanilla block: [3x3 conv, BN, Relu] x2"""

def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
assert (
bm is None and gw is None and se_r is None
), "Vanilla block does not support bm, gw, and se_r options"
super(VanillaBlock, self).__init__()
self._construct(w_in, w_out, stride)

def _construct(self, w_in, w_out, stride):
# 3x3, BN, ReLU
self.a = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN, ReLU
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class BasicTransform(nn.Module):
"""Basic transformation: [3x3 conv, BN, Relu] x2"""

def __init__(self, w_in, w_out, stride):
super(BasicTransform, self).__init__()
self._construct(w_in, w_out, stride)

def _construct(self, w_in, w_out, stride):
# 3x3, BN, ReLU
self.a = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class ResBasicBlock(nn.Module):
"""Residual basic block: x + F(x), F = basic transform"""

def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
assert (
bm is None and gw is None and se_r is None
), "Basic transform does not support bm, gw, and se_r options"
super(ResBasicBlock, self).__init__()
self._construct(w_in, w_out, stride)

def _add_skip_proj(self, w_in, w_out, stride):
self.proj = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

def _construct(self, w_in, w_out, stride):
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self._add_skip_proj(w_in, w_out, stride)
self.f = BasicTransform(w_in, w_out, stride)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x


class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block"""

def __init__(self, w_in, w_se):
super(SE, self).__init__()
self._construct(w_in, w_se)

def _construct(self, w_in, w_se):
# AvgPool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# FC, Activation, FC, Sigmoid
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
nn.Sigmoid(),
)

def forward(self, x):
return x * self.f_ex(self.avg_pool(x))


class BottleneckTransform(nn.Module):
"""Bottlenect transformation: 1x1, 3x3, 1x1"""

def __init__(self, w_in, w_out, stride, bm, gw, se_r):
super(BottleneckTransform, self).__init__()
self._construct(w_in, w_out, stride, bm, gw, se_r)

def _construct(self, w_in, w_out, stride, bm, gw, se_r):
# Compute the bottleneck width
w_b = int(round(w_out * bm))
# Compute the number of groups
num_gs = w_b // gw
# 1x1, BN, ReLU
self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN, ReLU
self.b = nn.Conv2d(
w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False
)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# Squeeze-and-Excitation (SE)
if se_r:
w_se = int(round(w_in * se_r))
self.se = SE(w_b, w_se)
# 1x1, BN
self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class ResBottleneckBlock(nn.Module):
"""Residual bottleneck block: x + F(x), F = bottleneck transform"""

def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
super(ResBottleneckBlock, self).__init__()
self._construct(w_in, w_out, stride, bm, gw, se_r)

def _add_skip_proj(self, w_in, w_out, stride):
self.proj = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

def _construct(self, w_in, w_out, stride, bm, gw, se_r):
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self._add_skip_proj(w_in, w_out, stride)
self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x


class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR."""

def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self._construct(w_in, w_out)

def _construct(self, w_in, w_out):
# 3x3, BN, ReLU
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class ResStemIN(nn.Module):
"""ResNet stem for ImageNet."""

def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self._construct(w_in, w_out)

def _construct(self, w_in, w_out):
# 7x7, BN, ReLU, maxpool
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class SimpleStemIN(nn.Module):
"""Simple stem for ImageNet."""

def __init__(self, in_w, out_w):
super(SimpleStemIN, self).__init__()
self._construct(in_w, out_w)

def _construct(self, in_w, out_w):
# 3x3, BN, ReLU
self.conv = nn.Conv2d(
in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(out_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class AnyStage(nn.Module):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""

def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
super(AnyStage, self).__init__()
self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r)

def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
# Construct the blocks
for i in range(d):
# Stride and w_in apply to the first block of the stage
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
# Construct the block
self.add_module(
"b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bm, gw, se_r)
)

def forward(self, x):
for block in self.children():
x = block(x)
return x


class AnyNet(nn.Module):
"""AnyNet model."""

def __init__(self, **kwargs):
super(AnyNet, self).__init__()
if kwargs:
self._construct(
stem_type=kwargs["stem_type"],
stem_w=kwargs["stem_w"],
block_type=kwargs["block_type"],
ds=kwargs["ds"],
ws=kwargs["ws"],
ss=kwargs["ss"],
bms=kwargs["bms"],
gws=kwargs["gws"],
se_r=kwargs["se_r"],
nc=kwargs["nc"],
)
else:
self._construct(
stem_type=cfg.ANYNET.STEM_TYPE,
stem_w=cfg.ANYNET.STEM_W,
block_type=cfg.ANYNET.BLOCK_TYPE,
ds=cfg.ANYNET.DEPTHS,
ws=cfg.ANYNET.WIDTHS,
ss=cfg.ANYNET.STRIDES,
bms=cfg.ANYNET.BOT_MULS,
gws=cfg.ANYNET.GROUP_WS,
se_r=cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
nc=cfg.MODEL.NUM_CLASSES,
)
self.apply(nu.init_weights)

def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
logger.info("Constructing AnyNet: ds={}, ws={}".format(ds, ws))
# Generate dummy bot muls and gs for models that do not use them
bms = bms if bms else [1.0 for _d in ds]
gws = gws if gws else [1 for _d in ds]
# Group params by stage
stage_params = list(zip(ds, ws, ss, bms, gws))
# Construct the stem
stem_fun = get_stem_fun(stem_type)
self.stem = stem_fun(3, stem_w)
# Construct the stages
block_fun = get_block_fun(block_type)
prev_w = stem_w
for i, (d, w, s, bm, gw) in enumerate(stage_params):
self.add_module(
"s{}".format(i + 1), AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r)
)
prev_w = w
# Construct the head
self.head = AnyHead(w_in=prev_w, nc=nc)

def forward(self, x):
for module in self.children():
x = module(x)
return x
235 changes: 235 additions & 0 deletions pycls/models/effnet.py
@@ -0,0 +1,235 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""EfficientNet models."""

import pycls.utils.logging as logging
import pycls.utils.net as nu
import torch
import torch.nn as nn
from pycls.core.config import cfg


logger = logging.get_logger(__name__)


class EffHead(nn.Module):
"""EfficientNet head."""

def __init__(self, w_in, w_out, nc):
super(EffHead, self).__init__()
self._construct(w_in, w_out, nc)

def _construct(self, w_in, w_out, nc):
# 1x1, BN, Swish
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False
)
self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.conv_swish = Swish()
# AvgPool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# Dropout
if cfg.EN.DROPOUT_RATIO > 0.0:
self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
# FC
self.fc = nn.Linear(w_out, nc, bias=True)

def forward(self, x):
x = self.conv_swish(self.conv_bn(self.conv(x)))
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x) if hasattr(self, "dropout") else x
x = self.fc(x)
return x


class Swish(nn.Module):
"""Swish activation function: x * sigmoid(x)"""

def __init__(self):
super(Swish, self).__init__()

def forward(self, x):
return x * torch.sigmoid(x)


class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block w/ Swish."""

def __init__(self, w_in, w_se):
super(SE, self).__init__()
self._construct(w_in, w_se)

def _construct(self, w_in, w_se):
# AvgPool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# FC, Swish, FC, Sigmoid
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
Swish(),
nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
nn.Sigmoid(),
)

def forward(self, x):
return x * self.f_ex(self.avg_pool(x))


class MBConv(nn.Module):
"""Mobile inverted bottleneck block w/ SE (MBConv)."""

def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
super(MBConv, self).__init__()
self._construct(w_in, exp_r, kernel, stride, se_r, w_out)

def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out):
# Expansion ratio is wrt the input width
self.exp = None
w_exp = int(w_in * exp_r)
# Include exp ops only if the exp ratio is different from 1
if w_exp != w_in:
# 1x1, BN, Swish
self.exp = nn.Conv2d(
w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False
)
self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.exp_swish = Swish()
# 3x3 dwise, BN, Swish
self.dwise = nn.Conv2d(
w_exp,
w_exp,
kernel_size=kernel,
stride=stride,
groups=w_exp,
bias=False,
# Hacky padding to preserve res (supports only 3x3 and 5x5)
padding=(1 if kernel == 3 else 2),
)
self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.dwise_swish = Swish()
# Squeeze-and-Excitation (SE)
w_se = int(w_in * se_r)
self.se = SE(w_exp, w_se)
# 1x1, BN
self.lin_proj = nn.Conv2d(
w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False
)
self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
# Skip connection if in and out shapes are the same (MN-V2 style)
self.has_skip = (stride == 1) and (w_in == w_out)

def forward(self, x):
f_x = x
# Expansion
if self.exp:
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
# Depthwise
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
# SE
f_x = self.se(f_x)
# Linear projection
f_x = self.lin_proj_bn(self.lin_proj(f_x))
# Skip connection
if self.has_skip:
# Drop connect
if self.training and cfg.EN.DC_RATIO > 0.0:
f_x = nu.drop_connect(f_x, cfg.EN.DC_RATIO)
f_x = x + f_x
return f_x


class EffStage(nn.Module):
"""EfficientNet stage."""

def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
super(EffStage, self).__init__()
self._construct(w_in, exp_r, kernel, stride, se_r, w_out, d)

def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
# Construct the blocks
for i in range(d):
# Stride and input width apply to the first block of the stage
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
# Construct the block
self.add_module(
"b{}".format(i + 1),
MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out),
)

def forward(self, x):
for block in self.children():
x = block(x)
return x


class StemIN(nn.Module):
"""EfficientNet stem for ImageNet."""

def __init__(self, w_in, w_out):
super(StemIN, self).__init__()
self._construct(w_in, w_out)

def _construct(self, w_in, w_out):
# 3x3, BN, Swish
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.swish = Swish()

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class EffNet(nn.Module):
"""EfficientNet model."""

def __init__(self):
assert cfg.TRAIN.DATASET in [
"imagenet"
], "Training on {} is not supported".format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in [
"imagenet"
], "Testing on {} is not supported".format(cfg.TEST.DATASET)
super(EffNet, self).__init__()
self._construct(
stem_w=cfg.EN.STEM_W,
ds=cfg.EN.DEPTHS,
ws=cfg.EN.WIDTHS,
exp_rs=cfg.EN.EXP_RATIOS,
se_r=cfg.EN.SE_R,
ss=cfg.EN.STRIDES,
ks=cfg.EN.KERNELS,
head_w=cfg.EN.HEAD_W,
nc=cfg.MODEL.NUM_CLASSES,
)
self.apply(nu.init_weights)

def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
# Group params by stage
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
logger.info("Constructing: EfficientNet-{}".format(stage_params))
# Construct the stem
self.stem = StemIN(3, stem_w)
prev_w = stem_w
# Construct the stages
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
self.add_module(
"s{}".format(i + 1), EffStage(prev_w, exp_r, kernel, stride, se_r, w, d)
)
prev_w = w
# Construct the head
self.head = EffHead(prev_w, head_w, nc)

def forward(self, x):
for module in self.children():
x = module(x)
return x
86 changes: 86 additions & 0 deletions pycls/models/regnet.py
@@ -0,0 +1,86 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""RegNet models."""

import numpy as np
import pycls.utils.logging as lu
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet


logger = lu.get_logger(__name__)


def quantize_float(f, q):
"""Converts a float to closest non-zero int divisible by q."""
return int(round(f / q) * q)


def adjust_ws_gs_comp(ws, bms, gs):
"""Adjusts the compatibility of widths and groups."""
ws_bot = [int(w * b) for w, b in zip(ws, bms)]
gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
return ws, gs


def get_stages_from_blocks(ws, rs):
"""Gets ws/ds of network at each stage from per block values."""
ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
return s_ws, s_ds


def generate_regnet(w_a, w_0, w_m, d, q=8):
"""Generates per block ws from RegNet parameters."""
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
ws_cont = np.arange(d) * w_a + w_0
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
ws = w_0 * np.power(w_m, ks)
ws = np.round(np.divide(ws, q)) * q
num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
return ws, num_stages, max_stage, ws_cont


class RegNet(AnyNet):
"""RegNet model."""

def __init__(self):
# Generate RegNet ws per block
b_ws, num_s, _, _ = generate_regnet(
cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
)
# Convert to per stage format
ws, ds = get_stages_from_blocks(b_ws, b_ws)
# Generate group widths and bot muls
gws = [cfg.REGNET.GROUP_W for _ in range(num_s)]
bms = [cfg.REGNET.BOT_MUL for _ in range(num_s)]
# Adjust the compatibility of ws and gws
ws, gws = adjust_ws_gs_comp(ws, bms, gws)
# Use the same stride for each stage
ss = [cfg.REGNET.STRIDE for _ in range(num_s)]
# Use SE for RegNetY
se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None
# Construct the model
kwargs = {
"stem_type": cfg.REGNET.STEM_TYPE,
"stem_w": cfg.REGNET.STEM_W,
"block_type": cfg.REGNET.BLOCK_TYPE,
"ss": ss,
"ds": ds,
"ws": ws,
"bms": bms,
"gws": gws,
"se_r": se_r,
"nc": cfg.MODEL.NUM_CLASSES,
}
super(RegNet, self).__init__(**kwargs)
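
The values in `RegNetY-8.0GF_dds_8gpu.yaml` (DEPTH=17, W0=192, WA=76.82, WM=2.19, GROUP_W=56) plug straight into the helpers above; a sketch that prints the per-stage widths, depths, and group widths the backbone is built from (BOT_MUL is not set in the YAML, so pycls's default of 1.0 is assumed):

```python
from pycls.models.regnet import (
    adjust_ws_gs_comp, generate_regnet, get_stages_from_blocks,
)

# Parameters copied from RegNetY-8.0GF_dds_8gpu.yaml
b_ws, num_s, _, _ = generate_regnet(w_a=76.82, w_0=192, w_m=2.19, d=17)
ws, ds = get_stages_from_blocks(b_ws, b_ws)
ws, gws = adjust_ws_gs_comp(ws, [1.0] * num_s, [56] * num_s)
print(ws, ds, gws)  # per-stage widths, depths, and group widths
```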
275 changes: 275 additions & 0 deletions pycls/models/resnet.py
@@ -0,0 +1,275 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""ResNe(X)t models."""

import pycls.utils.logging as lu
import pycls.utils.net as nu
import torch.nn as nn
from pycls.core.config import cfg


logger = lu.get_logger(__name__)


# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}


def get_trans_fun(name):
"""Retrieves the transformation function by name."""
trans_funs = {
"basic_transform": BasicTransform,
"bottleneck_transform": BottleneckTransform,
}
assert (
name in trans_funs.keys()
), "Transformation function '{}' not supported".format(name)
return trans_funs[name]


class ResHead(nn.Module):
"""ResNet head."""

def __init__(self, w_in, nc):
super(ResHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)

def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x


class BasicTransform(nn.Module):
"""Basic transformation: 3x3, 3x3"""

def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
assert (
w_b is None and num_gs == 1
), "Basic transform does not support w_b and num_gs options"
super(BasicTransform, self).__init__()
self._construct(w_in, w_out, stride)

def _construct(self, w_in, w_out, stride):
# 3x3, BN, ReLU
self.a = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN
self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, 3x3, 1x1"""

def __init__(self, w_in, w_out, stride, w_b, num_gs):
super(BottleneckTransform, self).__init__()
self._construct(w_in, w_out, stride, w_b, num_gs)

def _construct(self, w_in, w_out, stride, w_b, num_gs):
# MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
(str1x1, str3x3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
# 1x1, BN, ReLU
self.a = nn.Conv2d(
w_in, w_b, kernel_size=1, stride=str1x1, padding=0, bias=False
)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 3x3, BN, ReLU
self.b = nn.Conv2d(
w_b, w_b, kernel_size=3, stride=str3x3, padding=1, groups=num_gs, bias=False
)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
# 1x1, BN
self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class ResBlock(nn.Module):
"""Residual block: x + F(x)"""

def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
super(ResBlock, self).__init__()
self._construct(w_in, w_out, stride, trans_fun, w_b, num_gs)

def _add_skip_proj(self, w_in, w_out, stride):
self.proj = nn.Conv2d(
w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

def _construct(self, w_in, w_out, stride, trans_fun, w_b, num_gs):
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self._add_skip_proj(w_in, w_out, stride)
self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x


class ResStage(nn.Module):
"""Stage of ResNet."""

def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
super(ResStage, self).__init__()
self._construct(w_in, w_out, stride, d, w_b, num_gs)

def _construct(self, w_in, w_out, stride, d, w_b, num_gs):
# Construct the blocks
for i in range(d):
# Stride and w_in apply to the first block of the stage
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
# Retrieve the transformation function
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
# Construct the block
res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
self.add_module("b{}".format(i + 1), res_block)

def forward(self, x):
for block in self.children():
x = block(x)
return x


class ResStem(nn.Module):
"""Stem of ResNet."""

def __init__(self, w_in, w_out):
assert (
cfg.TRAIN.DATASET == cfg.TEST.DATASET
), "Train and test dataset must be the same for now"
super(ResStem, self).__init__()
if "cifar" in cfg.TRAIN.DATASET:
self._construct_cifar(w_in, w_out)
else:
self._construct_imagenet(w_in, w_out)

def _construct_cifar(self, w_in, w_out):
# 3x3, BN, ReLU
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

def _construct_imagenet(self, w_in, w_out):
# 7x7, BN, ReLU, maxpool
self.conv = nn.Conv2d(
w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

def forward(self, x):
for layer in self.children():
x = layer(x)
return x


class ResNet(nn.Module):
"""ResNet model."""

def __init__(self):
assert cfg.TRAIN.DATASET in [
"cifar10",
"imagenet",
], "Training ResNet on {} is not supported".format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in [
"cifar10",
"imagenet",
], "Testing ResNet on {} is not supported".format(cfg.TEST.DATASET)
super(ResNet, self).__init__()
if "cifar" in cfg.TRAIN.DATASET:
self._construct_cifar()
else:
self._construct_imagenet()
self.apply(nu.init_weights)

def _construct_cifar(self):
assert (
cfg.MODEL.DEPTH - 2
) % 6 == 0, "Model depth should be of the format 6n + 2 for cifar"
logger.info("Constructing: ResNet-{}".format(cfg.MODEL.DEPTH))
# Each stage has the same number of blocks for cifar
d = int((cfg.MODEL.DEPTH - 2) / 6)
# Stem: (N, 3, 32, 32) -> (N, 16, 32, 32)
self.stem = ResStem(w_in=3, w_out=16)
# Stage 1: (N, 16, 32, 32) -> (N, 16, 32, 32)
self.s1 = ResStage(w_in=16, w_out=16, stride=1, d=d)
# Stage 2: (N, 16, 32, 32) -> (N, 32, 16, 16)
self.s2 = ResStage(w_in=16, w_out=32, stride=2, d=d)
# Stage 3: (N, 32, 16, 16) -> (N, 64, 8, 8)
self.s3 = ResStage(w_in=32, w_out=64, stride=2, d=d)
# Head: (N, 64, 8, 8) -> (N, num_classes)
self.head = ResHead(w_in=64, nc=cfg.MODEL.NUM_CLASSES)

def _construct_imagenet(self):
logger.info(
"Constructing: ResNe(X)t-{}-{}x{}, {}".format(
cfg.MODEL.DEPTH,
cfg.RESNET.NUM_GROUPS,
cfg.RESNET.WIDTH_PER_GROUP,
cfg.RESNET.TRANS_FUN,
)
)
# Retrieve the number of blocks per stage
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
# Compute the initial bottleneck width
num_gs = cfg.RESNET.NUM_GROUPS
w_b = cfg.RESNET.WIDTH_PER_GROUP * num_gs
# Stem: (N, 3, 224, 224) -> (N, 64, 56, 56)
self.stem = ResStem(w_in=3, w_out=64)
# Stage 1: (N, 64, 56, 56) -> (N, 256, 56, 56)
self.s1 = ResStage(w_in=64, w_out=256, stride=1, d=d1, w_b=w_b, num_gs=num_gs)
# Stage 2: (N, 256, 56, 56) -> (N, 512, 28, 28)
self.s2 = ResStage(
w_in=256, w_out=512, stride=2, d=d2, w_b=w_b * 2, num_gs=num_gs
)
# Stage 3: (N, 512, 28, 28) -> (N, 1024, 14, 14)
self.s3 = ResStage(
w_in=512, w_out=1024, stride=2, d=d3, w_b=w_b * 4, num_gs=num_gs
)
# Stage 4: (N, 1024, 14, 14) -> (N, 2048, 7, 7)
self.s4 = ResStage(
w_in=1024, w_out=2048, stride=2, d=d4, w_b=w_b * 8, num_gs=num_gs
)
# Head: (N, 2048, 7, 7) -> (N, num_classes)
self.head = ResHead(w_in=2048, nc=cfg.MODEL.NUM_CLASSES)

def forward(self, x):
for module in self.children():
x = module(x)
return x
Empty file added pycls/utils/__init__.py
Empty file.
Binary file added pycls/utils/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added pycls/utils/__pycache__/benchmark.cpython-37.pyc
Binary file not shown.
Binary file added pycls/utils/__pycache__/checkpoint.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added pycls/utils/__pycache__/io.cpython-37.pyc
Binary file not shown.
Binary file added pycls/utils/__pycache__/logging.cpython-37.pyc
Binary file not shown.
Binary file added pycls/utils/__pycache__/meters.cpython-37.pyc
Binary file not shown.
Binary file added pycls/utils/__pycache__/metrics.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file added pycls/utils/__pycache__/net.cpython-37.pyc
Binary file not shown.
Binary file added pycls/utils/__pycache__/timer.cpython-37.pyc
Binary file not shown.
89 changes: 89 additions & 0 deletions pycls/utils/benchmark.py
@@ -0,0 +1,89 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions for benchmarking networks."""

import pycls.utils.logging as lu
import torch
from pycls.core.config import cfg
from pycls.utils.timer import Timer


@torch.no_grad()
def compute_fw_test_time(model, inputs):
"""Computes forward test time (no grad, eval mode)."""
# Use eval mode
model.eval()
# Warm up the caches
for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
model(inputs)
# Make sure warmup kernels completed
torch.cuda.synchronize()
# Compute precise forward pass time
timer = Timer()
for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
timer.tic()
model(inputs)
torch.cuda.synchronize()
timer.toc()
# Make sure forward kernels completed
torch.cuda.synchronize()
return timer.average_time


def compute_fw_bw_time(model, loss_fun, inputs, labels):
"""Computes forward backward time."""
# Use train mode
model.train()
# Warm up the caches
for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
preds = model(inputs)
loss = loss_fun(preds, labels)
loss.backward()
# Make sure warmup kernels completed
torch.cuda.synchronize()
# Compute precise forward backward pass time
fw_timer = Timer()
bw_timer = Timer()
for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
# Forward
fw_timer.tic()
preds = model(inputs)
loss = loss_fun(preds, labels)
torch.cuda.synchronize()
fw_timer.toc()
# Backward
bw_timer.tic()
loss.backward()
torch.cuda.synchronize()
bw_timer.toc()
# Make sure forward backward kernels completed
torch.cuda.synchronize()
return fw_timer.average_time, bw_timer.average_time


def compute_precise_time(model, loss_fun):
"""Computes precise time."""
# Generate a dummy mini-batch
im_size = cfg.TRAIN.IM_SIZE
inputs = torch.rand(cfg.PREC_TIME.BATCH_SIZE, 3, im_size, im_size)
labels = torch.zeros(cfg.PREC_TIME.BATCH_SIZE, dtype=torch.int64)
# Copy the data to the GPU
inputs = inputs.cuda(non_blocking=False)
labels = labels.cuda(non_blocking=False)
# Compute precise time
fw_test_time = compute_fw_test_time(model, inputs)
fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels)
# Log precise time
lu.log_json_stats(
{
"prec_test_fw_time": fw_test_time,
"prec_train_fw_time": fw_time,
"prec_train_bw_time": bw_time,
"prec_train_fw_bw_time": fw_time + bw_time,
}
)
91 changes: 91 additions & 0 deletions pycls/utils/checkpoint.py
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.utils.distributed as du
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
"""Retrieves the location for storing checkpoints."""
return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
"""Retrieves the path to a checkpoint file."""
name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
"""Retrieves the most recent checkpoint (highest epoch number)."""
checkpoint_dir = get_checkpoint_dir()
# Checkpoint file names are in lexicographic order
checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
last_checkpoint_name = sorted(checkpoints)[-1]
return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
"""Determines if there are checkpoints available."""
checkpoint_dir = get_checkpoint_dir()
if not os.path.exists(checkpoint_dir):
return False
return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def is_checkpoint_epoch(cur_epoch):
"""Determines if a checkpoint should be saved on current epoch."""
return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0


def save_checkpoint(model, optimizer, epoch):
"""Saves a checkpoint."""
# Save checkpoints only from the master process
if not du.is_master_proc():
return
# Ensure that the checkpoint dir exists
os.makedirs(get_checkpoint_dir(), exist_ok=True)
# Omit the DDP wrapper in the multi-gpu setting
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
# Record the state
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
torch.save(checkpoint, checkpoint_file)
return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
"""Loads the checkpoint from the given file."""
assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format(
checkpoint_file
)
# Load the checkpoint on CPU to avoid GPU mem spike
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# Account for the DDP wrapper in the multi-gpu setting
ms = model.module if cfg.NUM_GPUS > 1 else model
ms.load_state_dict(checkpoint["model_state"])
# Load the optimizer state (commonly not done when fine-tuning)
if optimizer:
optimizer.load_state_dict(checkpoint["optimizer_state"])
return checkpoint["epoch"]
61 changes: 61 additions & 0 deletions pycls/utils/distributed.py
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import torch
from pycls.core.config import cfg


def is_master_proc():
"""Determines if the current process is the master process.
Master process is responsible for logging, writing and loading checkpoints.
In the multi GPU setting, we assign the master role to the rank 0 process.
    When training using a single GPU, there is only one training process,
    which is considered the master process.
"""
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0


def init_process_group(proc_rank, world_size):
"""Initializes the default process group."""
# Set the GPU to use
torch.cuda.set_device(proc_rank)
# Initialize the process group
torch.distributed.init_process_group(
backend=cfg.DIST_BACKEND,
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
world_size=world_size,
rank=proc_rank,
)


def destroy_process_group():
"""Destroys the default process group."""
torch.distributed.destroy_process_group()


def scaled_all_reduce(tensors):
"""Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place. Currently supports only the sum
reduction operator. The reduced values are scaled by the inverse size of
the process group (equivalent to cfg.NUM_GPUS).
"""
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / cfg.NUM_GPUS)
return tensors
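
A hypothetical sketch of scaled_all_reduce inside a process group initialized with init_process_group (the loss value and device placement are illustrative only):

# Hypothetical sketch: average a per-process metric across workers.
# Assumes init_process_group(proc_rank, world_size) was called in this process
# and that cfg.NUM_GPUS matches the world size, as in the trainer.
import torch
import pycls.utils.distributed as du

loss = torch.tensor([0.5]).cuda()       # this process' local loss
(loss,) = du.scaled_all_reduce([loss])  # summed in-place, then scaled by 1 / cfg.NUM_GPUS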
59 changes: 59 additions & 0 deletions pycls/utils/error_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Multiprocessing error handler."""

import os
import signal
import threading


class ChildException(Exception):
"""Wraps an exception from a child process."""

def __init__(self, child_trace):
super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
"""Multiprocessing error handler (based on fairseq's).
Listens for errors in child processes and
propagates the tracebacks to the parent process.
"""

def __init__(self, error_queue):
# Shared error queue
self.error_queue = error_queue
# Children processes sharing the error queue
self.children_pids = []
# Start a thread listening to errors
self.error_listener = threading.Thread(target=self.listen, daemon=True)
self.error_listener.start()
# Register the signal handler
signal.signal(signal.SIGUSR1, self.signal_handler)

def add_child(self, pid):
"""Registers a child process."""
self.children_pids.append(pid)

def listen(self):
"""Listens for errors in the error queue."""
# Wait until there is an error in the queue
child_trace = self.error_queue.get()
# Put the error back for the signal handler
self.error_queue.put(child_trace)
# Invoke the signal handler
os.kill(os.getpid(), signal.SIGUSR1)

def signal_handler(self, _sig_num, _stack_frame):
"""Signal handler."""
# Kill children processes
for pid in self.children_pids:
os.kill(pid, signal.SIGINT)
# Propagate the error from the child process
raise ChildException(self.error_queue.get())
90 changes: 90 additions & 0 deletions pycls/utils/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""IO utilities (adapted from Detectron)"""

import logging
import os
import re
import sys
from urllib import request as urlrequest


logger = logging.getLogger(__name__)

_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"


def cache_url(url_or_file, cache_dir):
"""Download the file specified by the URL to the cache_dir and return the
path to the cached file. If the argument is not a URL, simply return it as
is.
"""
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None

if not is_url:
return url_or_file

url = url_or_file
assert url.startswith(_PYCLS_BASE_URL), (
"pycls only automatically caches URLs in the pycls S3 bucket: {}"
).format(_PYCLS_BASE_URL)

cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
return cache_file_path

cache_file_dir = os.path.dirname(cache_file_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir)

logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
return cache_file_path


def _progress_bar(count, total):
"""Report download progress.
Credit:
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
"""
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))

percents = round(100.0 * count / float(total), 1)
bar = "=" * filled_len + "-" * (bar_len - filled_len)

sys.stdout.write(
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
)
sys.stdout.flush()
if count >= total:
sys.stdout.write("\n")


def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path.
Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
"""
req = urlrequest.Request(url)
response = urlrequest.urlopen(req)
total_size = response.info().get("Content-Length").strip()
total_size = int(total_size)
bytes_so_far = 0

with open(dst_file_path, "wb") as f:
while 1:
chunk = response.read(chunk_size)
bytes_so_far += len(chunk)
if not chunk:
break
if progress_hook:
progress_hook(bytes_so_far, total_size)
f.write(chunk)

return bytes_so_far
100 changes: 100 additions & 0 deletions pycls/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Logging."""

import builtins
import decimal
import logging
import os
import sys

import pycls.utils.distributed as du
import simplejson
from pycls.core.config import cfg


# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"

# Printed json stats lines will be tagged w/ this
_TAG = "json_stats: "


def _suppress_print():
"""Suppresses printing from the current process."""

def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
pass

builtins.print = ignore


def setup_logging():
"""Sets up the logging."""
# Enable logging only for the master process
if du.is_master_proc():
# Clear the root logger to prevent any existing logging config
# (e.g. set by another module) from messing with our setup
logging.root.handlers = []
# Construct logging configuration
logging_config = {"level": logging.INFO, "format": _FORMAT}
# Log either to stdout or to a file
if cfg.LOG_DEST == "stdout":
logging_config["stream"] = sys.stdout
else:
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
# Configure logging
logging.basicConfig(**logging_config)
else:
_suppress_print()


def get_logger(name):
"""Retrieves the logger."""
return logging.getLogger(name)


def log_json_stats(stats):
"""Logs json stats."""
# Decimal + string workaround for having fixed len float vals in logs
stats = {
k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
for k, v in stats.items()
}
json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
logger = get_logger(__name__)
logger.info("{:s}{:s}".format(_TAG, json_stats))


def load_json_stats(log_file):
"""Loads json_stats from a single log file."""
with open(log_file, "r") as f:
lines = f.readlines()
json_lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
json_stats = [simplejson.loads(l) for l in json_lines]
return json_stats


def parse_json_stats(log, row_type, key):
"""Extract values corresponding to row_type/key out of log."""
vals = [row[key] for row in log if row["_type"] == row_type and key in row]
if key == "iter" or key == "epoch":
vals = [int(val.split("/")[0]) for val in vals]
return vals


def get_log_files(log_dir, name_filter=""):
"""Get all log files in directory containing subdirs of trained models."""
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
files = [os.path.join(log_dir, n, _LOG_FILE) for n in names]
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
files, names = zip(*f_n_ps)
return files, names
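
For reference, a small hypothetical sketch of the json stats format produced by log_json_stats (the basicConfig call stands in for setup_logging, and the stats values are made up):

# Hypothetical sketch: floats are rendered with fixed 6-decimal precision
# and the line carries the "json_stats: " tag so it can be parsed back later.
import logging
import sys
import pycls.utils.logging as lu

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
lu.log_json_stats({"_type": "test_epoch", "epoch": "1/100", "top1_err": 12.3456789})
# the emitted stats line ends with: json_stats: {"_type": "test_epoch", "epoch": "1/100", "top1_err": 12.345679}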
47 changes: 47 additions & 0 deletions pycls/utils/lr_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Learning rate policies."""

import numpy as np
from pycls.core.config import cfg


def lr_fun_steps(cur_epoch):
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)


def lr_fun_exp(cur_epoch):
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)


def lr_fun_cos(cur_epoch):
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))


def get_lr_fun():
"""Retrieves the specified lr policy function"""
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
if lr_fun not in globals():
        raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY)
return globals()[lr_fun]


def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch according to the policy."""
lr = get_lr_fun()(cur_epoch)
# Linear warmup
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
lr *= warmup_factor
return lr
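
To make the schedule concrete, here is a standalone, hypothetical re-derivation of the cosine policy with linear warmup (assumed values: base_lr = 0.4, max_epoch = 100, warmup_epochs = 5, warmup_factor = 0.1):

# Hypothetical standalone sketch of the warmup + cosine policy above.
import numpy as np

def sketch_epoch_lr(cur_epoch, base_lr=0.4, max_epoch=100, warmup_epochs=5, warmup_factor=0.1):
    lr = 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
    if cur_epoch < warmup_epochs:
        alpha = cur_epoch / warmup_epochs
        lr *= warmup_factor * (1.0 - alpha) + alpha
    return lr

print([round(sketch_epoch_lr(e), 4) for e in (0, 1, 5, 50, 99)])
# epoch 0 starts at warmup_factor * base_lr (0.04); epoch 50 sits at base_lr / 2 (0.2)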
239 changes: 239 additions & 0 deletions pycls/utils/meters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Meters."""

import datetime
from collections import deque

import numpy as np
import pycls.utils.logging as lu
import pycls.utils.metrics as metrics
from pycls.core.config import cfg
from pycls.utils.timer import Timer


def eta_str(eta_td):
"""Converts an eta timedelta to a fixed-width string format."""
days = eta_td.days
hrs, rem = divmod(eta_td.seconds, 3600)
mins, secs = divmod(rem, 60)
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)


class ScalarMeter(object):
"""Measures a scalar value (adapted from Detectron)."""

def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0

def reset(self):
self.deque.clear()
self.total = 0.0
self.count = 0

def add_value(self, value):
self.deque.append(value)
self.count += 1
self.total += value

def get_win_median(self):
return np.median(self.deque)

def get_win_avg(self):
return np.mean(self.deque)

def get_global_avg(self):
return self.total / self.count


class TrainMeter(object):
"""Measures training stats."""

def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0

def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0

def iter_tic(self):
self.iter_timer.tic()

def iter_toc(self):
self.iter_timer.toc()

def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
# Current minibatch stats
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size

def get_iter_stats(self, cur_epoch, cur_iter):
eta_sec = self.iter_timer.average_time * (
self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1)
)
eta_td = datetime.timedelta(seconds=int(eta_sec))
mem_usage = metrics.gpu_mem_usage()
stats = {
"_type": "train_iter",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": eta_str(eta_td),
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats

def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
lu.log_json_stats(stats)

def get_epoch_stats(self, cur_epoch):
eta_sec = self.iter_timer.average_time * (
self.max_iter - (cur_epoch + 1) * self.epoch_iters
)
eta_td = datetime.timedelta(seconds=int(eta_sec))
mem_usage = metrics.gpu_mem_usage()
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
avg_loss = self.loss_total / self.num_samples
stats = {
"_type": "train_epoch",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": eta_str(eta_td),
"top1_err": top1_err,
"top5_err": top5_err,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats

def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
lu.log_json_stats(stats)


class TestMeter(object):
"""Measures testing stats."""

def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Min errors (over the full test set)
self.min_top1_err = 100.0
self.min_top5_err = 100.0
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0

def reset(self, min_errs=False):
if min_errs:
self.min_top1_err = 100.0
self.min_top5_err = 100.0
self.iter_timer.reset()
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0

def iter_tic(self):
self.iter_timer.tic()

def iter_toc(self):
self.iter_timer.toc()

def update_stats(self, top1_err, top5_err, mb_size):
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.num_samples += mb_size

def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = metrics.gpu_mem_usage()
iter_stats = {
"_type": "test_iter",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats

def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
lu.log_json_stats(stats)

def get_epoch_stats(self, cur_epoch):
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
self.min_top1_err = min(self.min_top1_err, top1_err)
self.min_top5_err = min(self.min_top5_err, top5_err)
mem_usage = metrics.gpu_mem_usage()
stats = {
"_type": "test_epoch",
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"top1_err": top1_err,
"top5_err": top5_err,
"min_top1_err": self.min_top1_err,
"min_top5_err": self.min_top5_err,
"mem": int(np.ceil(mem_usage)),
}
return stats

def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
lu.log_json_stats(stats)
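
A tiny hypothetical sketch of the windowed behaviour of ScalarMeter (window of 3, made-up values):

# Hypothetical sketch: the deque keeps only the last window_size values,
# while total/count track the global running average.
from pycls.utils.meters import ScalarMeter

m = ScalarMeter(window_size=3)
for v in (1.0, 2.0, 3.0, 4.0):
    m.add_value(v)
print(m.get_win_median())  # 3.0 (median over the last three values: 2, 3, 4)
print(m.get_global_avg())  # 2.5 (average over all four values)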
104 changes: 104 additions & 0 deletions pycls/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions for computing metrics."""

import numpy as np
import torch
import torch.nn as nn
from pycls.core.config import cfg


# Number of bytes in a megabyte
_B_IN_MB = 1024 * 1024


def topks_correct(preds, labels, ks):
"""Computes the number of top-k correct predictions for each k."""
assert preds.size(0) == labels.size(
0
), "Batch dim of predictions and labels must match"
# Find the top max_k predictions for each sample
_top_max_k_vals, top_max_k_inds = torch.topk(
preds, max(ks), dim=1, largest=True, sorted=True
)
# (batch_size, max_k) -> (max_k, batch_size)
top_max_k_inds = top_max_k_inds.t()
# (batch_size, ) -> (max_k, batch_size)
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
# Compute the number of topk correct predictions for each k
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
return topks_correct


def topk_errors(preds, labels, ks):
"""Computes the top-k error for each k."""
num_topks_correct = topks_correct(preds, labels, ks)
return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]


def topk_accuracies(preds, labels, ks):
"""Computes the top-k accuracy for each k."""
num_topks_correct = topks_correct(preds, labels, ks)
return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]


def params_count(model):
"""Computes the number of parameters."""
return np.sum([p.numel() for p in model.parameters()]).item()


def flops_count(model):
"""Computes the number of flops statically."""
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
count = 0
for n, m in model.named_modules():
if isinstance(m, nn.Conv2d):
if "se." in n:
count += m.in_channels * m.out_channels + m.bias.numel()
continue
h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
count += np.prod([m.weight.numel(), h_out, w_out])
if ".proj" not in n:
h, w = h_out, w_out
elif isinstance(m, nn.MaxPool2d):
h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
elif isinstance(m, nn.Linear):
count += m.in_features * m.out_features + m.bias.numel()
return count.item()


def acts_count(model):
"""Computes the number of activations statically."""
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
count = 0
for n, m in model.named_modules():
if isinstance(m, nn.Conv2d):
if "se." in n:
count += m.out_channels
continue
h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
count += np.prod([m.out_channels, h_out, w_out])
if ".proj" not in n:
h, w = h_out, w_out
elif isinstance(m, nn.MaxPool2d):
h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
elif isinstance(m, nn.Linear):
count += m.out_features
return count.item()


def gpu_mem_usage():
"""Computes the GPU memory usage for the current device (MB)."""
mem_usage_bytes = torch.cuda.max_memory_allocated()
return mem_usage_bytes / _B_IN_MB
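
A quick hypothetical sketch of topk_errors on toy predictions (two samples, four classes):

# Hypothetical sketch: sample 0's top-1 guess is wrong but its top-2 set
# contains the label; sample 1 is correct outright.
import torch
from pycls.utils.metrics import topk_errors

preds = torch.tensor([[0.1, 0.7, 0.2, 0.0],
                      [0.8, 0.1, 0.05, 0.05]])
labels = torch.tensor([2, 0])
top1_err, top2_err = topk_errors(preds, labels, [1, 2])
print(float(top1_err), float(top2_err))  # 50.0 0.0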
57 changes: 57 additions & 0 deletions pycls/utils/multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Multiprocessing helpers."""

import multiprocessing as mp
import traceback

import pycls.utils.distributed as du
from pycls.utils.error_handler import ErrorHandler


def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
"""Runs a function from a child process."""
try:
# Initialize the process group
du.init_process_group(proc_rank, world_size)
# Run the function
fun(*fun_args, **fun_kwargs)
except KeyboardInterrupt:
# Killed by the parent process
pass
except Exception:
# Propagate exception to the parent process
error_queue.put(traceback.format_exc())
finally:
# Destroy the process group
du.destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
"""Runs a function in a multi-proc setting."""

if fun_kwargs is None:
fun_kwargs = {}

# Handle errors from training subprocesses
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)

# Run each training subprocess
ps = []
for i in range(num_proc):
p_i = mp.Process(
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
)
ps.append(p_i)
p_i.start()
error_handler.add_child(p_i.pid)

# Wait for each subprocess to finish
for p in ps:
p.join()
94 changes: 94 additions & 0 deletions pycls/utils/net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions for manipulating networks."""

import itertools
import math

import torch
import torch.nn as nn
from pycls.core.config import cfg


def init_weights(m):
"""Performs ResNet-style weight initialization."""
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
elif isinstance(m, nn.BatchNorm2d):
zero_init_gamma = (
hasattr(m, "final_bn") and m.final_bn and cfg.BN.ZERO_INIT_FINAL_GAMMA
)
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()


@torch.no_grad()
def compute_precise_bn_stats(model, loader):
"""Computes precise BN stats on training data."""
# Compute the number of minibatches to use
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
# Retrieve the BN layers
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
# Initialize stats storage
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
# Remember momentum values
moms = [bn.momentum for bn in bns]
# Disable momentum
for bn in bns:
bn.momentum = 1.0
# Accumulate the stats across the data samples
for inputs, _labels in itertools.islice(loader, num_iter):
model(inputs.cuda())
# Accumulate the stats for each BN layer
for i, bn in enumerate(bns):
m, v = bn.running_mean, bn.running_var
sqs[i] += (v + m * m) / num_iter
mus[i] += m / num_iter
# Set the stats and restore momentum values
for i, bn in enumerate(bns):
bn.running_var = sqs[i] - mus[i] * mus[i]
bn.running_mean = mus[i]
bn.momentum = moms[i]


def reset_bn_stats(model):
"""Resets running BN stats."""
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.reset_running_stats()


def drop_connect(x, drop_ratio):
"""Drop connect (adapted from DARTS)."""
keep_ratio = 1.0 - drop_ratio
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
mask.bernoulli_(keep_ratio)
x.div_(keep_ratio)
x.mul_(mask)
return x


def get_flat_weights(model):
"""Gets all model weights as a single flat vector."""
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)


def set_flat_weights(model, flat_weights):
"""Sets all model weights from a single flat vector."""
k = 0
for p in model.parameters():
n = p.data.numel()
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
k += n
assert k == flat_weights.numel()
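
A short hypothetical sketch of the flat-weight helpers on a toy model:

# Hypothetical sketch: round-trip all parameters through a single flat vector.
import torch.nn as nn
import pycls.utils.net as nu

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))
flat = nu.get_flat_weights(model)   # shape [num_params, 1]
nu.set_flat_weights(model, flat)    # restores the identical weights
print(flat.numel() == sum(p.numel() for p in model.parameters()))  # True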
132 changes: 132 additions & 0 deletions pycls/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Plotting functions."""

import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.utils.logging as lu


def get_plot_colors(max_colors, color_format="pyplot"):
"""Generate colors for plotting."""
colors = cl.scales["11"]["qual"]["Paired"]
if max_colors > len(colors):
colors = cl.to_rgb(cl.interp(colors, max_colors))
if color_format == "pyplot":
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
return colors


def prepare_plot_data(log_files, names, key="top1_err"):
"""Load logs and extract data for plotting error curves."""
plot_data = []
for file, name in zip(log_files, names):
d, log = {}, lu.load_json_stats(file)
for phase in ["train", "test"]:
x = lu.parse_json_stats(log, phase + "_epoch", "epoch")
y = lu.parse_json_stats(log, phase + "_epoch", key)
d["x_" + phase], d["y_" + phase] = x, y
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
plot_data.append(d)
assert len(plot_data) > 0, "No data to plot"
return plot_data


def plot_error_curves_plotly(log_files, names, filename, key="top1_err"):
"""Plot error curves using plotly and save to file."""
plot_data = prepare_plot_data(log_files, names, key)
colors = get_plot_colors(len(plot_data), "plotly")
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
data = []
for i, d in enumerate(plot_data):
s = str(i)
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=True,
showlegend=False,
)
)
data.append(
go.Scatter(
x=d["x_test"],
y=d["y_test"],
mode="lines",
name=d["test_label"],
line=line_test,
legendgroup=s,
visible=True,
showlegend=True,
)
)
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=False,
showlegend=True,
)
)
# Prepare layout w ability to toggle 'all', 'train', 'test'
titlefont = {"size": 18, "color": "#7f7f7f"}
vis = [[True, True, False], [False, False, True], [False, True, False]]
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
buttons = [{"label": l, "args": v, "method": "update"} for l, v in buttons]
layout = go.Layout(
title=key + " vs. epoch<br>[dash=train, solid=test]",
xaxis={"title": "epoch", "titlefont": titlefont},
yaxis={"title": key, "titlefont": titlefont},
showlegend=True,
hoverlabel={"namelength": -1},
updatemenus=[
{
"buttons": buttons,
"direction": "down",
"showactive": True,
"x": 1.02,
"xanchor": "left",
"y": 1.08,
"yanchor": "top",
}
],
)
# Create plotly plot
offline.plot({"data": data, "layout": layout}, filename=filename)


def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_err"):
"""Plot error curves using matplotlib.pyplot and save to file."""
plot_data = prepare_plot_data(log_files, names, key)
colors = get_plot_colors(len(names))
for ind, d in enumerate(plot_data):
c, lbl = colors[ind], d["test_label"]
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
plt.title(key + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
plt.xlabel("epoch", fontsize=14)
plt.ylabel(key, fontsize=14)
plt.grid(alpha=0.4)
plt.legend()
if filename:
plt.savefig(filename)
plt.clf()
else:
plt.show()
35 changes: 35 additions & 0 deletions pycls/utils/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Timer."""

import time


class Timer(object):
"""A simple timer (adapted from Detectron)."""

def __init__(self):
self.reset()

def tic(self):
        # using time.time instead of time.clock because time.clock
        # does not normalize for multithreading
self.start_time = time.time()

def toc(self):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls

def reset(self):
self.total_time = 0.0
self.calls = 0
self.start_time = 0.0
self.diff = 0.0
self.average_time = 0.0
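
A hypothetical usage sketch of the Timer (the sleep stands in for one training iteration):

# Hypothetical sketch: time a few iterations and read the running average.
import time
from pycls.utils.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)  # stand-in for one iteration of work
    timer.toc()
print(timer.diff, timer.average_time)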
78 changes: 78 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@

import torch
import argparse
import numpy as np

from PIL import Image

from RegDanbooru2019_8G import RegDanbooru2019

parser = argparse.ArgumentParser(description='Test RegDeepDanbooru')
parser.add_argument('--model', default='', type=str, help='trained model')
parser.add_argument('--image', default='', type=str, help='image to test')
parser.add_argument('--size', default=768, type=int, help='canvas size')
parser.add_argument('--threshold', default=0.5, type=float, help='threshold')
args = parser.parse_args()

DANBOORU_LABEL_MAP = {}

def load_danbooru_label_map() :
print(' -- Loading danbooru2019 labels')
global DANBOORU_LABEL_MAP
with open('danbooru_labels.txt', 'r') as fp :
for l in fp :
l = l.strip()
if l :
idx, tag = l.split(' ')
DANBOORU_LABEL_MAP[int(idx)] = tag

def test(model, image_resized) :
print(' -- Running model on GPU')
image_resized_torch = torch.from_numpy(image_resized).float() / 127.5 - 1.0
if len(image_resized_torch.shape) == 3 :
image_resized_torch = image_resized_torch.unsqueeze(0).permute(0, 3, 1, 2)
elif len(image_resized_torch.shape) == 4 :
image_resized_torch = image_resized_torch.permute(0, 3, 1, 2)
image_resized_torch = image_resized_torch.cuda()
with torch.no_grad() :
danbooru_logits = model(image_resized_torch)
danbooru = danbooru_logits.sigmoid().cpu()
return danbooru

def load_and_resize_image(img_path, canvas_size = 512) :
    # Resize so the longest side equals canvas_size while preserving the aspect ratio
    img = Image.open(img_path).convert('RGB')
    old_size = img.size
    ratio = float(canvas_size) / max(old_size)
    new_size = tuple([int(round(x * ratio)) for x in old_size])
    print(f'Test image size: {new_size}')
    return np.array(img.resize(new_size, Image.ANTIALIAS))

def translate_danbooru_labels(probs, threshold = 0.8) :
    global DANBOORU_LABEL_MAP
    # Indices of every (sample, tag) pair whose probability exceeds the threshold
    chosen_indices = (probs > threshold).nonzero()
    result = []
    for i in range(probs.size(0)) :
        prob_single = probs[i].numpy()
        indices_single = chosen_indices[chosen_indices[:, 0] == i][:, 1].numpy()
        tag_prob_map = {DANBOORU_LABEL_MAP[idx]: prob_single[idx] for idx in indices_single}
        result.append(tag_prob_map)
    return result

def main() :
    model = RegDanbooru2019().cuda()
    model.load_state_dict(torch.load(args.model)['model'])
    model.eval()
    # Also export the full module for later reuse
    torch.save(model, 'RegNetY-8G.pth')

    test_img = load_and_resize_image(args.image, args.size)

    danbooru = test(model, test_img)

    tags = translate_danbooru_labels(danbooru, args.threshold)
    print(tags)

if __name__ == "__main__":
load_danbooru_label_map()
main()
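
For reference, load_danbooru_label_map above expects danbooru_labels.txt to contain one space-separated `<index> <tag>` pair per line; a hypothetical excerpt (the actual mapping is defined by whatever danbooru_labels.txt accompanies the model, so the tags below are illustrative only):

0 1girl
1 solo
2 long_hair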
