-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
49ff1fe
commit 24e66cc
Showing
6 changed files
with
604 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
We use ResNet and Swin as the backbone in our model. | ||
|
||
* `convert-torchvision-to-d2.py` | ||
|
||
Tool to convert torchvision pre-trained weights for D2. | ||
|
||
``` | ||
wget https://download.pytorch.org/models/resnet101-63fe2227.pth | ||
python tools/convert-torchvision-to-d2.py resnet101-63fe2227.pth R-101.pkl | ||
``` | ||
|
||
* `convert-pretrained-swin-model-to-d2.py` | ||
|
||
Tool to convert Swin Transformer pre-trained weights for D2. | ||
|
||
``` | ||
pip install timm | ||
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth | ||
python tools/convert-pretrained-swin-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl | ||
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth | ||
python tools/convert-pretrained-swin-model-to-d2.py swin_small_patch4_window7_224.pth swin_small_patch4_window7_224.pkl | ||
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth | ||
python tools/convert-pretrained-swin-model-to-d2.py swin_base_patch4_window12_384_22k.pth swin_base_patch4_window12_384_22k.pkl | ||
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth | ||
python tools/convert-pretrained-swin-model-to-d2.py swin_large_patch4_window12_384_22k.pth swin_large_patch4_window12_384_22k.pkl | ||
``` | ||
|
||
|
||
* `analyze_model.py` | ||
|
||
Tool to analyze model parameters and flops. | ||
|
||
Usage for semantic segmentation (ADE20K only, use with caution!): | ||
|
||
``` | ||
python tools/analyze_model.py --num-inputs 1 --tasks flop --use-fixed-input-size --config-file CONFIG_FILE | ||
``` | ||
|
||
Note that, for semantic segmentation (ADE20K only), we use a dummy image with a fixed size equal to `cfg.INPUT.CROP.SIZE[0] x cfg.INPUT.CROP.SIZE[0]`. | ||
Please do not use `--use-fixed-input-size` for calculating FLOPs on other datasets like Cityscapes! | ||
|
||
Usage for panoptic and instance segmentation: | ||
|
||
``` | ||
python tools/analyze_model.py --num-inputs 100 --tasks flop --config-file CONFIG_FILE | ||
``` | ||
|
||
Note that, for panoptic and instance segmentation, we compute the average flops over 100 real validation images. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# Modified by Bowen Cheng from https://github.com/facebookresearch/detectron2/blob/main/tools/analyze_model.py | ||
|
||
import logging | ||
import numpy as np | ||
from collections import Counter | ||
import tqdm | ||
from fvcore.nn import flop_count_table # can also try flop_count_str | ||
|
||
from detectron2.checkpoint import DetectionCheckpointer | ||
from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate | ||
from detectron2.data import build_detection_test_loader | ||
from detectron2.engine import default_argument_parser | ||
from detectron2.modeling import build_model | ||
from detectron2.projects.deeplab import add_deeplab_config | ||
from detectron2.utils.analysis import ( | ||
FlopCountAnalysis, | ||
activation_count_operators, | ||
parameter_count_table, | ||
) | ||
from detectron2.utils.logger import setup_logger | ||
|
||
# fmt: off | ||
import os | ||
import sys | ||
sys.path.insert(1, os.path.join(sys.path[0], '..')) | ||
# fmt: on | ||
|
||
from mask2former import add_maskformer2_config | ||
|
||
logger = logging.getLogger("detectron2") | ||
|
||
|
||
def setup(args):
    """
    Create the config used by the analysis tasks and initialise the loggers.

    A config path ending in ``.yaml`` is loaded as a yacs-style config, with
    the DeepLab and MaskFormer2 keys registered before the file is merged.
    Any other path is loaded through ``LazyConfig``. In both cases
    ``args.opts`` is applied as overrides.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts``
            are read.

    Returns:
        The frozen ``CfgNode`` for yaml configs, otherwise the lazy config.
    """
    config_path = args.config_file
    if not config_path.endswith(".yaml"):
        lazy_cfg = LazyConfig.load(config_path)
        cfg = LazyConfig.apply_overrides(lazy_cfg, args.opts)
    else:
        cfg = get_cfg()
        for register_keys in (add_deeplab_config, add_maskformer2_config):
            register_keys(cfg)
        cfg.merge_from_file(config_path)
        # Applied after the file merge but before the command-line overrides,
        # so an explicit override in args.opts still takes effect.
        cfg.DATALOADER.NUM_WORKERS = 0
        cfg.merge_from_list(args.opts)
        cfg.freeze()

    # One logger for fvcore, one with the default name.
    setup_logger(name="fvcore")
    setup_logger()
    return cfg
|
||
|
||
def do_flop(cfg):
    """
    Count flops of the model over up to ``args.num_inputs`` test samples.

    Logs the flop table of the last processed sample, the per-operator
    average GFlops, and the mean/std of the total GFlops.

    When ``args.use_fixed_input_size`` is set and ``cfg`` is a ``CfgNode``,
    the image of the first item of every batch is replaced by an all-zero
    tensor of shape ``(3, S, S)`` with ``S = cfg.INPUT.CROP.SIZE[0]``.

    Args:
        cfg: a ``CfgNode`` or a lazy config, as returned by ``setup``.

    Note:
        Reads the module-level ``args`` set by the script entry point.
    """
    # Imported once here instead of on every loop iteration.
    import torch

    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    # Loop-invariant: only yacs configs expose INPUT.CROP.SIZE here.
    use_fixed_size = args.use_fixed_input_size and isinstance(cfg, CfgNode)

    counts = Counter()
    total_flops = []
    flops = None
    for idx, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        if use_fixed_size:
            crop_size = cfg.INPUT.CROP.SIZE[0]
            data[0]["image"] = torch.zeros((3, crop_size, crop_size))
        flops = FlopCountAnalysis(model, data)
        if idx > 0:
            # Warnings were already emitted for the first sample.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    if flops is None:
        # Nothing was processed (num_inputs <= 0 or an empty data loader);
        # the statistics below would otherwise fail on unbound names.
        logger.warning("No input samples were processed; skipping flop statistics.")
        return

    num_samples = len(total_flops)
    logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops))
    logger.info(
        "Average GFlops for each type of operators:\n"
        + str([(k, v / num_samples / 1e9) for k, v in counts.items()])
    )
    logger.info(
        "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9)
    )
|
||
|
||
def do_activation(cfg):
    """
    Count activations of the model over up to ``args.num_inputs`` test samples.

    Logs the per-operator activation count averaged over the processed
    samples, and the mean/std of the per-sample totals.

    Args:
        cfg: a ``CfgNode`` or a lazy config, as returned by ``setup``.

    Note:
        Reads the module-level ``args`` set by the script entry point.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    counts = Counter()
    total_activations = []
    for _, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        count = activation_count_operators(model, data)
        counts += count
        total_activations.append(sum(count.values()))

    num_samples = len(total_activations)
    if num_samples == 0:
        # Nothing was processed (num_inputs <= 0 or an empty data loader).
        logger.warning("No input samples were processed; skipping activation statistics.")
        return

    # Average over the number of samples, not the last zero-based index:
    # dividing by the index fails for a single input and is off by one otherwise.
    logger.info(
        "(Million) Activations for Each Type of Operators:\n"
        + str([(k, v / num_samples) for k, v in counts.items()])
    )
    logger.info(
        "Total (Million) Activations: {}±{}".format(
            np.mean(total_activations), np.std(total_activations)
        )
    )
|
||
|
||
def do_parameter(cfg):
    """
    Log a parameter-count table (``max_depth=5``) for the configured model.

    The model is built with ``build_model`` for a ``CfgNode`` and with
    ``instantiate(cfg.model)`` otherwise. No checkpoint is loaded here.
    """
    model = build_model(cfg) if isinstance(cfg, CfgNode) else instantiate(cfg.model)
    table = parameter_count_table(model, max_depth=5)
    logger.info("Parameter Count:\n" + table)
|
||
|
||
def do_structure(cfg):
    """
    Log the string representation of the configured model.

    The model is built with ``build_model`` for a ``CfgNode`` and with
    ``instantiate(cfg.model)`` otherwise. No checkpoint is loaded here.
    """
    model = build_model(cfg) if isinstance(cfg, CfgNode) else instantiate(cfg.model)
    description = str(model)
    logger.info("Model Structure:\n" + description)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse the command line, build the config, then run
    # every requested analysis task in the order given.
    parser = default_argument_parser(
        epilog="""
Examples:
To show parameters of a model:
$ ./analyze_model.py --tasks parameter \\
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
Flops and activations are data-dependent, therefore inputs and model weights
are needed to count them:
$ ./analyze_model.py --num-inputs 100 --tasks flop \\
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\
MODEL.WEIGHTS /path/to/model.pkl
"""
    )
    parser.add_argument(
        "--tasks",
        choices=["flop", "activation", "parameter", "structure"],
        required=True,
        nargs="+",
    )
    parser.add_argument(
        "-n",
        "--num-inputs",
        default=100,
        type=int,
        help="number of inputs used to compute statistics for flops/activations, "
        "both are data dependent.",
    )
    parser.add_argument(
        "--use-fixed-input-size",
        action="store_true",
        help="use fixed input size when calculating flops",
    )
    # `args` must stay a module-level name: do_flop and do_activation read it.
    args = parser.parse_args()

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    # and gives the user no usable message.
    if args.eval_only:
        parser.error("eval_only is not supported by this tool")
    if args.num_gpus != 1:
        parser.error("this tool requires num_gpus == 1 (got {})".format(args.num_gpus))

    cfg = setup(args)

    # Built once rather than on every loop iteration.
    task_handlers = {
        "flop": do_flop,
        "activation": do_activation,
        "parameter": do_parameter,
        "structure": do_structure,
    }
    for task in args.tasks:
        task_handlers[task](cfg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
|
||
import pickle as pkl | ||
import sys | ||
|
||
import torch | ||
|
||
""" | ||
Usage: | ||
# download pretrained swin model: | ||
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth | ||
# run the conversion | ||
./convert-pretrained-swin-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl | ||
# Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config: | ||
MODEL: | ||
WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl" | ||
INPUT: | ||
FORMAT: "RGB" | ||
""" | ||
|
||
if __name__ == "__main__":
    # Convert a checkpoint whose weights live under the "model" key into a
    # pickle file: argv[1] is the source checkpoint, argv[2] the destination.
    if len(sys.argv) < 3:
        # Fail with a usage message instead of a bare IndexError.
        sys.exit("Usage: {} INPUT.pth OUTPUT.pkl".format(sys.argv[0]))
    # Named paths avoid shadowing the builtin `input`.
    src_path, dst_path = sys.argv[1], sys.argv[2]

    obj = torch.load(src_path, map_location="cpu")["model"]

    res = {"model": obj, "__author__": "third_party", "matching_heuristics": True}

    with open(dst_path, "wb") as f:
        pkl.dump(res, f)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
import pickle as pkl | ||
import sys | ||
|
||
import torch | ||
|
||
""" | ||
Usage: | ||
# download one of the ResNet{18,34,50,101,152} models from torchvision: | ||
wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth | ||
# run the conversion | ||
./convert-torchvision-to-d2.py r50.pth r50.pkl | ||
# Then, use r50.pkl with the following changes in config: | ||
MODEL: | ||
WEIGHTS: "/path/to/r50.pkl" | ||
PIXEL_MEAN: [123.675, 116.280, 103.530] | ||
PIXEL_STD: [58.395, 57.120, 57.375] | ||
RESNETS: | ||
DEPTH: 50 | ||
STRIDE_IN_1X1: False | ||
INPUT: | ||
FORMAT: "RGB" | ||
""" | ||
|
||
if __name__ == "__main__":
    # Rename the keys of a state dict (argv[1]) and write the result as a
    # pickle file (argv[2]). Every renaming is printed as "old -> new".
    if len(sys.argv) < 3:
        # Fail with a usage message instead of a bare IndexError.
        sys.exit("Usage: {} INPUT.pth OUTPUT.pkl".format(sys.argv[0]))
    # Named paths avoid shadowing the builtin `input`.
    src_path, dst_path = sys.argv[1], sys.argv[2]

    obj = torch.load(src_path, map_location="cpu")

    newmodel = {}
    # Iterate over a snapshot of the keys because entries are popped below.
    for old_k in list(obj):
        k = old_k
        # Keys that do not mention "layer" get the "stem." prefix.
        if "layer" not in k:
            k = "stem." + k
        # layer1..layer4 -> res2..res5
        for t in [1, 2, 3, 4]:
            k = k.replace("layer{}".format(t), "res{}".format(t + 1))
        # bn1..bn3 -> conv1.norm..conv3.norm
        for t in [1, 2, 3]:
            k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
        k = k.replace("downsample.0", "shortcut")
        k = k.replace("downsample.1", "shortcut.norm")
        print(old_k, "->", k)
        newmodel[k] = obj.pop(old_k).detach().numpy()

    res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}

    with open(dst_path, "wb") as f:
        pkl.dump(res, f)
    # Anything still left in `obj` was not converted.
    if obj:
        print("Unconverted keys:", obj.keys())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
# Modified by Bowen Cheng from: https://github.com/bowenc0221/boundary-iou-api/blob/master/tools/coco_instance_evaluation.py | ||
|
||
""" | ||
Evaluation for COCO val2017: | ||
python ./tools/coco_instance_evaluation.py \ | ||
--gt-json-file COCO_GT_JSON \ | ||
--dt-json-file COCO_DT_JSON | ||
""" | ||
import argparse | ||
import json | ||
|
||
from boundary_iou.coco_instance_api.coco import COCO | ||
from boundary_iou.coco_instance_api.cocoeval import COCOeval | ||
|
||
|
||
def main():
    """
    Evaluate detection results against ground truth with the boundary-IoU COCO API.

    Command-line options:
        --gt-json-file: path to the ground-truth annotation JSON.
        --dt-json-file: path to the detection-results JSON.
        --iou-type: IoU type passed to ``COCOeval`` (default ``"boundary"``);
            boundaries are requested from ``COCO`` only for ``"boundary"``.
        --dilation-ratio: float passed to both ``COCO`` and ``COCOeval``.

    Any ``"bbox"`` entry is dropped from each detection before evaluation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gt-json-file", default="")
    parser.add_argument("--dt-json-file", default="")
    parser.add_argument("--iou-type", default="boundary")
    parser.add_argument("--dilation-ratio", default="0.020", type=float)
    args = parser.parse_args()
    print(args)

    dilation_ratio = args.dilation_ratio
    get_boundary = args.iou_type == "boundary"
    cocoGt = COCO(args.gt_json_file, get_boundary=get_boundary, dilation_ratio=dilation_ratio)

    # Load inside a context manager so the file handle is closed promptly.
    with open(args.dt_json_file) as f:
        detections = json.load(f)

    # remove box predictions
    for det in detections:
        det.pop("bbox", None)

    cocoDt = cocoGt.loadRes(detections)
    cocoEval = COCOeval(cocoGt, cocoDt, iouType=args.iou_type, dilation_ratio=dilation_ratio)
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()


if __name__ == '__main__':
    main()
Oops, something went wrong.