Skip to content

Commit 8cfc8ee

Browse files
Support TSM-R50 Python API (wang-xinyu#488)
* add tensorrt temporal shift module and related pytorch implementations * add .gitignore and getn weights script. * rename get_wts.py script * Add tsm-r50 demo. * update readme * remove useless codes * update readme * update readme * remote video and .gitignore, update tutorial * update readme and tutorial * fix a few bugs and test on tensorrt 5.1 * update readme
1 parent d9bdd7e commit 8cfc8ee

8 files changed

+869
-4
lines changed

lenet/lenet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def load_weights(file):
2727

2828
weight_map = {}
2929
with open(file, "r") as f:
30-
lines = f.readlines()
30+
lines = [line.strip() for line in f]
3131
count = int(lines[0])
3232
assert count == len(lines) - 1
3333
for i in range(1, count + 1):

resnet/resnet50.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def load_weights(file):
2929

3030
weight_map = {}
3131
with open(file, "r") as f:
32-
lines = f.readlines()
32+
lines = [line.strip() for line in f]
3333
count = int(lines[0])
3434
assert count == len(lines) - 1
3535
for i in range(1, count + 1):
@@ -138,7 +138,7 @@ def bottleneck(network, weight_map, input, in_channels, out_channels, stride,
138138
return relu3
139139

140140

141-
def createLenetEngine(maxBatchSize, builder, config, dt):
141+
def create_engine(maxBatchSize, builder, config, dt):
142142
weight_map = load_weights(WEIGHT_PATH)
143143
network = builder.create_network()
144144

@@ -233,7 +233,7 @@ def createLenetEngine(maxBatchSize, builder, config, dt):
233233
def APIToModel(maxBatchSize):
234234
builder = trt.Builder(TRT_LOGGER)
235235
config = builder.create_builder_config()
236-
engine = createLenetEngine(maxBatchSize, builder, config, trt.float32)
236+
engine = create_engine(maxBatchSize, builder, config, trt.float32)
237237
assert engine
238238
with open(ENGINE_PATH, "wb") as f:
239239
f.write(engine.serialize())

tsm/README.md

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Temporal Shift Module
2+
3+
TSM-R50 from "TSM: Temporal Shift Module for Efficient Video Understanding" <https://arxiv.org/abs/1811.08383>
4+
5+
TSM is a widely used Action Recognition model. This TensorRT implementation is tested with TensorRT 5.1 and TensorRT 7.2.
6+
7+
For the PyTorch implementation, you can refer to [open-mmlab/mmaction2](https://github.com/open-mmlab/mmaction2) or [mit-han-lab/temporal-shift-module](https://github.com/mit-han-lab/temporal-shift-module).
8+
9+
More details about the shift module(which is the core of TSM) could to [test_shift.py](./test_shift.py).
10+
11+
## Tutorial
12+
13+
+ An example could refer to [demo.sh](./demo.sh)
14+
+ Requirements: Successfully installed `torch>=1.3.0, torchvision`
15+
16+
+ Step 1: Train/Download TSM-R50 checkpoints from [offical Github repo](https://github.com/mit-han-lab/temporal-shift-module) or [MMAction2](https://github.com/open-mmlab/mmaction2)
17+
+ Supported settings: `num_segments`, `shift_div`, `num_classes`.
18+
+ Fixed settings: `backbone`(ResNet50), `shift_place`(blockres), `temporal_pool`(False).
19+
20+
+ Step 2: Convert PyTorch checkpoints to TensorRT weights.
21+
22+
```shell
23+
python gen_wts.py /path/to/pytorch.pth --out-filename /path/to/tensorrt.wts
24+
```
25+
26+
+ Step 3: Modify configs in `tsm_r50.py`.
27+
28+
```python
29+
BATCH_SIZE = 1
30+
NUM_SEGMENTS = 8
31+
INPUT_H = 224
32+
INPUT_W = 224
33+
OUTPUT_SIZE = 400
34+
SHIFT_DIV = 8
35+
```
36+
37+
+ Step 4: Inference with `tsm_r50.py`.
38+
39+
```shell
40+
usage: tsm_r50.py [-h] [--tensorrt-weights TENSORRT_WEIGHTS] [--input-video INPUT_VIDEO] [--save-engine-path SAVE_ENGINE_PATH] [--load-engine-path LOAD_ENGINE_PATH] [--test-mmaction2] [--mmaction2-config MMACTION2_CONFIG] [--mmaction2-checkpoint MMACTION2_CHECKPOINT]
41+
42+
optional arguments:
43+
-h, --help show this help message and exit
44+
--tensorrt-weights TENSORRT_WEIGHTS
45+
Path to TensorRT weights, which is generated by gen_weights.py
46+
--input-video INPUT_VIDEO
47+
Path to local video file
48+
--save-engine-path SAVE_ENGINE_PATH
49+
Save engine to local file
50+
--load-engine-path LOAD_ENGINE_PATH
51+
Saved engine file path
52+
--test-mmaction2 Compare TensorRT results with MMAction2 Results
53+
--mmaction2-config MMACTION2_CONFIG
54+
Path to MMAction2 config file
55+
--mmaction2-checkpoint MMACTION2_CHECKPOINT
56+
Path to MMAction2 checkpoint url or file path
57+
```
58+
59+
## TODO
60+
61+
+ [x] Python Shift module.
62+
+ [x] Generate wts of official tsm and mmaction2 tsm.
63+
+ [x] Python API Definition
64+
+ [x] Test with mmaction2 demo
65+
+ [x] Tutorial
66+
+ [ ] C++ API Definition

tsm/demo.sh

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Step 1: Get checkpoints from mmaction2
2+
# https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/tsm
3+
wget https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_kinetics400_rgb/tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
4+
5+
# Step 2: Convert pytorch checkpoints to TensorRT weights
6+
python gen_wts.py tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth --out-filename ./tsm_r50_kinetics400_mmaction2.wts
7+
8+
# Step 3: Skip this step since we use default settings.
9+
10+
# Step 4: Inference
11+
# 1) Save local engine file to `./tsm_r50_kinetics400_mmaction2.trt`.
12+
python tsm_r50.py \
13+
--tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
14+
--save-engine-path ./tsm_r50_kinetics400_mmaction2.trt
15+
16+
# 2) Predict the recognition result using a single video `demo.mp4`.
17+
# Should print `Result class id 6`, aka `arm wrestling`
18+
# Download demo video
19+
wget https://raw.githubusercontent.com/open-mmlab/mmaction2/master/demo/demo.mp4
20+
# # use *.wts as input
21+
# python tsm_r50.py --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
22+
# --input-video ./demo.mp4
23+
# use engine file as input
24+
python tsm_r50.py --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
25+
--input-video ./demo.mp4
26+
27+
# 3) Optional: Compare inference result with MMAction2 TSM-R50 model
28+
# Have to install MMAction2 First, please refer to https://github.com/open-mmlab/mmaction2/blob/master/docs/install.md
29+
# pip3 install pytest-runner
30+
# pip3 install mmcv
31+
# pip3 install mmaction2
32+
# # use *.wts as input
33+
# python tsm_r50.py \
34+
# --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
35+
# --test-mmaction2 \
36+
# --mmaction2-config mmaction2_tsm_r50_config.py \
37+
# --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
38+
# # use TensorRT engine as input
39+
# python tsm_r50.py \
40+
# --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
41+
# --test-mmaction2 \
42+
# --mmaction2-config mmaction2_tsm_r50_config.py \
43+
# --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth

tsm/gen_wts.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import argparse
2+
import struct
3+
4+
import torch
5+
import numpy as np
6+
7+
8+
def write_one_weight(writer, name, weight):
9+
assert isinstance(weight, np.ndarray)
10+
values = weight.reshape(-1)
11+
writer.write('{} {}'.format(name, len(values)))
12+
for value in values:
13+
writer.write(' ')
14+
# float to bytes to hex_string
15+
writer.write(struct.pack('>f', float(value)).hex())
16+
writer.write('\n')
17+
18+
19+
def convert_name(name):
20+
return name.replace("module.", "").replace("base_model.", "").\
21+
replace("net.", "").replace("new_fc", "fc").replace("backbone.", "").\
22+
replace("cls_head.fc_cls", "fc").replace(".conv.", ".").\
23+
replace("conv1.bn", "bn1").replace("conv2.bn", "bn2").\
24+
replace("conv3.bn", "bn3").replace("downsample.bn", "downsample.1").\
25+
replace("downsample.weight", "downsample.0.weight")
26+
27+
28+
def main(args):
29+
ckpt = torch.load(args.checkpoint)['state_dict']
30+
ckpt = {k: v for k, v in ckpt.items() if 'num_batches_tracked' not in k}
31+
with open(args.out_filename, "w") as f:
32+
f.write(f"{len(ckpt)}\n")
33+
for k, v in ckpt.items():
34+
key = convert_name(k)
35+
write_one_weight(f, key, v.cpu().numpy())
36+
37+
38+
if __name__ == '__main__':
39+
parser = argparse.ArgumentParser()
40+
parser.add_argument("checkpoint", type=str, help="Path to checkpoint file")
41+
parser.add_argument("--out-filename",
42+
type=str,
43+
default="tsm_r50.wts",
44+
help="Path to converted wegiths file")
45+
args = parser.parse_args()
46+
main(args)

tsm/mmaction2_tsm_r50_config.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# model settings
2+
model = dict(
3+
type='Recognizer2D',
4+
backbone=dict(
5+
type='ResNetTSM',
6+
pretrained='torchvision://resnet50',
7+
depth=50,
8+
norm_eval=False,
9+
shift_div=8),
10+
cls_head=dict(
11+
type='TSMHead',
12+
num_classes=400,
13+
in_channels=2048,
14+
spatial_type='avg',
15+
consensus=dict(type='AvgConsensus', dim=1),
16+
dropout_ratio=0.5,
17+
init_std=0.001,
18+
is_shift=True),
19+
# model training and testing settings
20+
train_cfg=None,
21+
test_cfg=dict(average_clips='prob'))

0 commit comments

Comments
 (0)