Commit 4e91752
refactor code;
remove stts and i3d
leftthomas committed Jun 9, 2019
1 parent f679d75 commit 4e91752
Showing 6 changed files with 18 additions and 701 deletions.
21 changes: 11 additions & 10 deletions README.md
@@ -1,6 +1,7 @@
- # Two-Stream STTS
- A PyTorch implementation of Two-Stream Spatio-Temporal and Temporal-Spatio Convolutional Network based on the paper
- [Two-Stream Spatio-Temporal and Temporal-Spatio Convolutional Network for Activity Recognition]().
+ # C3D
+ A PyTorch implementation of C3D and R2Plus1D based on the paper
+ [Learning Spatiotemporal Features with 3D Convolutional Networks](https://arxiv.org/abs/1412.0767) and
+ [A Closer Look at Spatiotemporal Convolutions for Action Recognition](https://arxiv.org/abs/1711.11248).

## Requirements
- [Anaconda](https://www.anaconda.com/download/)
@@ -56,18 +57,18 @@ and [KINETICS600](https://deepmind.com/research/open-source/open-source-datasets
Download `UCF101` and `HMDB51` datasets with `train/val/test` split files into `data` directory.
We use the `split1` to split files. Run `misc.py` to preprocess these datasets.

- For `KINETICS600` dataset, first download `train/val/test` split files into `data` directory, and
- then run `download.py` to download and preprocess this dataset.
+ For `KINETICS600` dataset, first download `train/val/test` split files into `data` directory, then
+ run `download.py` to download and preprocess this dataset.

## Usage
### Train Model
```
- visdom -logging_level WARNING & python train.py --num_epochs 20 --pre_train kinetics600_stts-a.pth
+ visdom -logging_level WARNING & python train.py --num_epochs 20 --pre_train kinetics600_r2plus1d.pth
optional arguments:
--data_type dataset type [default value is 'ucf101'](choices=['ucf101', 'hmdb51', 'kinetics600'])
--gpu_ids selected gpu [default value is '0,1,2,3']
- --model_type model type [default value is 'stts-a'](choices=['stts-a', 'stts', 'i3d', 'r2plus1d', 'c3d'])
- --batch_size training batch size [default value is 16]
+ --model_type model type [default value is 'r2plus1d'](choices=['r2plus1d', 'c3d'])
+ --batch_size training batch size [default value is 64]
--num_epochs training epochs number [default value is 100]
--pre_train used pre-trained model epoch name [default value is None]
```
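As a quick illustration (not part of the diff above; it uses only the flags listed in the arguments block), training the C3D variant on HMDB51 from scratch could look like:
```
visdom -logging_level WARNING & python train.py --data_type hmdb51 --model_type c3d --batch_size 32 --num_epochs 50
```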
@@ -78,9 +79,9 @@ Visdom now can be accessed by going to `127.0.0.1:8097` in your browser.
python inference.py --video_name data/ucf101/ApplyLipstick/v_ApplyLipstick_g04_c02.avi
optional arguments:
--data_type dataset type [default value is 'ucf101'](choices=['ucf101', 'hmdb51', 'kinetics600'])
- --model_type model type [default value is 'stts-a'](choices=['stts-a', 'stts', 'i3d', 'r2plus1d', 'c3d'])
+ --model_type model type [default value is 'r2plus1d'](choices=['r2plus1d', 'c3d'])
--video_name test video name
- --model_name model epoch name [default value is 'ucf101_st-ts-a.pth']
+ --model_name model epoch name [default value is 'ucf101_r2plus1d.pth']
```
The inferences will show in a pop up window.
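Likewise, an illustrative inference call for a C3D checkpoint (not part of the diff; the file name simply follows the `{data_type}_{model_type}.pth` convention checked in `inference.py`) could be:
```
python inference.py --data_type ucf101 --model_type c3d --model_name ucf101_c3d.pth --video_name data/ucf101/ApplyLipstick/v_ApplyLipstick_g04_c02.avi
```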

13 changes: 3 additions & 10 deletions inference.py
@@ -8,9 +8,7 @@

import utils
from models.C3D import C3D
- from models.I3D import I3D
from models.R2Plus1D import R2Plus1D
- from models.STTS import STTS


def center_crop(image):
@@ -24,10 +22,9 @@ def center_crop(image):
parser = argparse.ArgumentParser(description='Test Activity Recognition')
parser.add_argument('--data_type', default='ucf101', type=str, choices=['ucf101', 'hmdb51', 'kinetics600'],
help='dataset type')
- parser.add_argument('--model_type', default='stts-a', type=str,
-                     choices=['stts-a', 'stts', 'i3d', 'r2plus1d', 'c3d'], help='model type')
+ parser.add_argument('--model_type', default='r2plus1d', type=str, choices=['r2plus1d', 'c3d'], help='model type')
parser.add_argument('--video_name', type=str, help='test video name')
- parser.add_argument('--model_name', default='ucf101_stts-a.pth', type=str, help='model epoch name')
+ parser.add_argument('--model_name', default='ucf101_r2plus1d.pth', type=str, help='model epoch name')
opt = parser.parse_args()

DATA_TYPE, MODEL_TYPE, VIDEO_NAME, MODEL_NAME = opt.data_type, opt.model_type, opt.video_name, opt.model_name
@@ -40,11 +37,7 @@ def center_crop(image):
if '{}_{}.pth'.format(DATA_TYPE, MODEL_TYPE) != MODEL_NAME:
    raise NotImplementedError('the model name must be the same model type and same data type')

- if MODEL_TYPE == 'stts-a' or MODEL_TYPE == 'stts':
-     model = STTS(len(class_names), (2, 2, 2, 2), MODEL_TYPE)
- elif MODEL_TYPE == 'i3d':
-     model = I3D(len(class_names))
- elif MODEL_TYPE == 'r2plus1d':
+ if MODEL_TYPE == 'r2plus1d':
    model = R2Plus1D(len(class_names), (2, 2, 2, 2))
else:
    model = C3D(len(class_names))
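For reference, a minimal standalone sketch of how the slimmed-down model selection might be exercised after this commit (not code from the repository; the class count and checkpoint file are illustrative, and the weight loading with `torch.load`/`load_state_dict` is plain PyTorch rather than something shown in this diff):
```
import torch

from models.C3D import C3D
from models.R2Plus1D import R2Plus1D

NUM_CLASSES = 101        # assumption: UCF101 has 101 action classes
MODEL_TYPE = 'r2plus1d'  # remaining choices after this commit: 'r2plus1d', 'c3d'

# mirror the refactored selection logic from inference.py
if MODEL_TYPE == 'r2plus1d':
    model = R2Plus1D(NUM_CLASSES, (2, 2, 2, 2))
else:
    model = C3D(NUM_CLASSES)

# hypothetical checkpoint named after the '{data_type}_{model_type}.pth' convention
state_dict = torch.load('ucf101_r2plus1d.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```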
