MemFlow: Optical Flow Estimation and Prediction with Memory
Qiaole Dong, Yanwei Fu
CVPR 2024
conda create --name memflow python=3.8
conda activate memflow
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install yacs loguru einops timm==0.4.12 imageio matplotlib tensorboard scipy opencv-python h5py tqdm
For faster training and inference, you can additionally install FlashAttention.
This FlashAttention wheel is compatible with our CUDA version; refer to the related issue.
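To confirm that the environment matches the pinned versions above, a small standard-library check can help. This is a minimal sketch and not part of the repository: the helper name `check_environment` is hypothetical, it treats a local build tag such as `+cu116` as matching the pinned version, and it only reports whether the optional `flash_attn` module can be found.

```python
import importlib.metadata
import importlib.util

# Versions pinned in the install commands above.
PINNED = {
    "torch": "1.13.1",
    "torchvision": "0.14.1",
    "torchaudio": "0.13.1",
    "timm": "0.4.12",
}


def check_environment(pinned=PINNED):
    """Return a status per package: 'ok', 'missing', or a mismatch note (hypothetical helper)."""
    report = {}
    for dist, wanted in pinned.items():
        try:
            found = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            report[dist] = "missing"
            continue
        # Ignore local build tags such as "+cu116" when comparing versions.
        base = found.split("+", 1)[0]
        report[dist] = "ok" if base == wanted else f"mismatch (found {found})"
    # FlashAttention is optional; only report whether the module can be found.
    has_flash = importlib.util.find_spec("flash_attn") is not None
    report["flash_attn (optional)"] = "installed" if has_flash else "not installed"
    return report


if __name__ == "__main__":
    for name, status in check_environment().items():
        print(f"{name:24s} {status}")
```

Run it inside the activated `memflow` environment; any `missing` or `mismatch` line points at a package to reinstall.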
We provide pretrained models. By default, evaluation looks for them at the following paths:
├── ckpts
    ├── MemFlowNet_things.pth
    ├── MemFlowNet_sintel.pth
    ├── MemFlowNet_kitti.pth
    ├── MemFlowNet_spring.pth
    ├── MemFlowNet_T_things.pth
    ├── MemFlowNet_T_things_kitti.pth
    ├── MemFlowNet_T_sintel.pth
    ├── MemFlowNet_T_kitti.pth
    ├── MemFlowNet_P_things.pth
    ├── MemFlowNet_P_sintel.pth
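Before running the demo or evaluation, you can check which of the checkpoints listed above are present. A minimal sketch: `missing_checkpoints` is a hypothetical helper, and its file list is copied from the default `ckpts` layout.

```python
from pathlib import Path

# Checkpoint file names from the default ckpts/ layout above.
CHECKPOINTS = [
    "MemFlowNet_things.pth",
    "MemFlowNet_sintel.pth",
    "MemFlowNet_kitti.pth",
    "MemFlowNet_spring.pth",
    "MemFlowNet_T_things.pth",
    "MemFlowNet_T_things_kitti.pth",
    "MemFlowNet_T_sintel.pth",
    "MemFlowNet_T_kitti.pth",
    "MemFlowNet_P_things.pth",
    "MemFlowNet_P_sintel.pth",
]


def missing_checkpoints(ckpt_dir="ckpts"):
    """List the expected checkpoint files that are not present in ckpt_dir (hypothetical helper)."""
    root = Path(ckpt_dir)
    return [name for name in CHECKPOINTS if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    print("all checkpoints found" if not missing else "missing: " + ", ".join(missing))
```

You only need the checkpoints for the stages you actually run, so a partial list is not necessarily an error.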
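Download the models and put them in the ckpts folder. Then run the demo with the following command, which reads an image sequence from --seq_dir and writes flow visualizations to --vis_dir: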
python -u inference.py --name MemFlowNet --stage sintel --restore_ckpt ckpts/MemFlowNet_sintel.pth --seq_dir demo_input_images --vis_dir demo_flow_vis
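To run the demo on several sequences in one go, the command can be wrapped in a small loop. This is a sketch under stated assumptions: the folder layout (one image sequence per sub-folder of `input_root`) and the helper names `build_inference_cmd` and `run_on_sequences` are hypothetical. It reuses only the flags from the command above, substitutes `sys.executable` for `python`, and keeps the Sintel stage and checkpoint as defaults.

```python
import subprocess
import sys
from pathlib import Path


def build_inference_cmd(seq_dir, vis_dir, stage="sintel",
                        ckpt="ckpts/MemFlowNet_sintel.pth"):
    """Build the demo command shown above as an argument list."""
    return [
        sys.executable, "-u", "inference.py",
        "--name", "MemFlowNet",
        "--stage", stage,
        "--restore_ckpt", str(ckpt),
        "--seq_dir", str(seq_dir),
        "--vis_dir", str(vis_dir),
    ]


def run_on_sequences(input_root, output_root, dry_run=False):
    """Run the demo once per sub-folder of input_root (hypothetical layout).

    Returns the commands; with dry_run=True nothing is executed.
    """
    commands = []
    for seq in sorted(p for p in Path(input_root).iterdir() if p.is_dir()):
        vis_dir = Path(output_root) / seq.name
        cmd = build_inference_cmd(seq, vis_dir)
        commands.append(cmd)
        if not dry_run:
            vis_dir.mkdir(parents=True, exist_ok=True)
            subprocess.run(cmd, check=True)  # run from the repository root
    return commands
```

Calling `run_on_sequences("my_sequences", "my_flow_vis", dry_run=True)` from the repository root (both folder names are placeholders) prints nothing and returns the commands, which is a cheap way to inspect them before a real run.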
Note: you can reduce _CN.val_decoder_depth in configs/sintel_memflownet.py from 15 to a smaller value to trade some accuracy for speed, as shown in Fig. 1.
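If you prefer to script this change, for example to compare several depths, the sketch below rewrites the assignment in the config file. It assumes, without confirmation from the repository, that the value is set on a single line of the form `_CN.val_decoder_depth = 15`; `set_val_decoder_depth` and `patch_config` are hypothetical helpers.

```python
import re
from pathlib import Path

# Assumed form of the assignment, e.g. "_CN.val_decoder_depth = 15".
_DEPTH_LINE = re.compile(r"^(\s*_CN\.val_decoder_depth\s*=\s*)\d+", re.MULTILINE)


def set_val_decoder_depth(config_text, depth):
    """Return config_text with the val_decoder_depth value replaced by depth."""
    if not isinstance(depth, int) or depth < 1:
        raise ValueError("depth must be a positive integer")
    new_text, count = _DEPTH_LINE.subn(lambda m: f"{m.group(1)}{depth}", config_text)
    if count != 1:
        raise ValueError(f"expected one val_decoder_depth assignment, found {count}")
    return new_text


def patch_config(depth, path="configs/sintel_memflownet.py"):
    """Rewrite the config file in place, keeping a .bak copy of the original."""
    path = Path(path)
    original = path.read_text()
    path.with_name(path.name + ".bak").write_text(original)
    path.write_text(set_val_decoder_depth(original, depth))
```

Failing loudly when the pattern is not found exactly once avoids silently leaving the depth at 15 if the config is formatted differently.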
To evaluate or train MemFlow, you will need to download the required datasets.
- FlyingThings3D
- Sintel
- KITTI
- HD1K (optional)
- Spring
By default, our code searches for the datasets in the locations below. You can create symbolic links in the datasets folder that point to wherever the datasets were downloaded:
├── datasets
    ├── Sintel
        ├── test
        ├── training
    ├── KITTI
        ├── testing
        ├── training
        ├── devkit
    ├── FlyingThings3D
        ├── frames_cleanpass
        ├── frames_finalpass
        ├── optical_flow
    ├── spring
        ├── test
        ├── training
        ├── flow_subsampling
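The symbolic links and the expected layout can be set up and checked with a short script. A minimal sketch: `link_datasets` and `check_layout` are hypothetical helpers, the expected sub-folders are copied from the layout above, and HD1K is not covered because its folder name is not listed there. On Windows, creating symlinks may require extra privileges.

```python
import os
from pathlib import Path

# Expected sub-folders of datasets/, copied from the layout above (HD1K is not listed there).
EXPECTED_LAYOUT = {
    "Sintel": ["test", "training"],
    "KITTI": ["testing", "training", "devkit"],
    "FlyingThings3D": ["frames_cleanpass", "frames_finalpass", "optical_flow"],
    "spring": ["test", "training", "flow_subsampling"],
}


def link_datasets(sources, datasets_dir="datasets"):
    """Symlink datasets_dir/<name> to each downloaded folder in `sources`.

    `sources` maps an expected folder name (e.g. "Sintel") to its download path.
    Existing entries are left untouched.
    """
    root = Path(datasets_dir)
    root.mkdir(parents=True, exist_ok=True)
    for name, target in sources.items():
        if name not in EXPECTED_LAYOUT:
            raise ValueError(f"unexpected dataset name: {name}")
        link = root / name
        if link.exists() or link.is_symlink():
            continue
        os.symlink(Path(target).resolve(), link, target_is_directory=True)


def check_layout(datasets_dir="datasets"):
    """Return 'Name/sub' entries from EXPECTED_LAYOUT that are missing."""
    root = Path(datasets_dir)
    return [
        f"{name}/{sub}"
        for name, subs in EXPECTED_LAYOUT.items()
        for sub in subs
        if not (root / name / sub).is_dir()
    ]
```

For example, `link_datasets({"Sintel": "/data/MPI-Sintel"})` (placeholder path) followed by `check_layout()` lists whatever is still missing.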
Please download the models to the ckpts folder. Then you can evaluate the provided models with the following script:
bash evaluate.sh
We used the following training schedule in our paper (2 A100/A6000 GPUs). Training logs are written to the logs folder and can be visualized with TensorBoard.
bash train.sh
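TensorBoard can be pointed at the whole folder with `tensorboard --logdir logs`, since it searches sub-folders for event files. If each training run writes its own sub-folder of logs (an assumption; the exact layout is not documented here), the sketch below picks the most recently modified one. `latest_run` is a hypothetical helper.

```python
from pathlib import Path


def latest_run(logs_dir="logs"):
    """Return the most recently modified sub-folder of logs_dir, or None.

    Assumes each training run writes its own sub-folder of logs_dir.
    """
    root = Path(logs_dir)
    if not root.is_dir():
        return None
    runs = [p for p in root.iterdir() if p.is_dir()]
    return max(runs, key=lambda p: p.stat().st_mtime, default=None)


if __name__ == "__main__":
    run = latest_run()
    # Fall back to the whole logs folder if no run sub-folder is found.
    target = run if run is not None else "logs"
    print(f"tensorboard --logdir {target}")
```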
If you found our paper helpful, please consider citing:
@inproceedings{dong2024memflow,
title={MemFlow: Optical Flow Estimation and Prediction with Memory},
author={Dong, Qiaole and Fu, Yanwei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2024}
}
Thanks to the following previously open-sourced repositories: