nimz/pytorch-checkpoint


pytorch-checkpoint

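Gradient checkpointing is a technique to reduce GPU memory cost. Instead of keeping every intermediate activation for the backward pass, it keeps only some of them and recomputes the rest during backpropagation, trading extra computation for lower memory use.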

Official implementation

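There is a PyTorch implementation in the official repo (torch.utils.checkpoint). However, it is extremely slow with multiple GPUs: on 2 GPUs with batch size 256 it takes 1.41s, compared with 0.27s without checkpointing.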
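For comparison, here is a minimal sketch of calling the official API, torch.utils.checkpoint.checkpoint, on a small block. It assumes a recent PyTorch release that accepts the use_reentrant keyword (much newer than the 1.0 build this repo was tested with); the block, the sizes, and the call counter are illustrative only. The counter makes the memory-for-compute trade visible: the checkpointed block runs once in the forward pass and once more when backward recomputes its activations, while the gradients match a plain run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
calls = {"n": 0}


def run_block(x):
    # Count executions so the recomputation in backward becomes visible.
    calls["n"] += 1
    return block(x)


x = torch.randn(4, 8, requires_grad=True)

# Plain forward/backward: every activation inside `block` is stored.
run_block(x).sum().backward()
plain_grad = block[0].weight.grad.clone()
block.zero_grad()
calls["n"] = 0

# Checkpointed: activations inside `block` are dropped, then recomputed in backward.
checkpoint(run_block, x, use_reentrant=False).sum().backward()
ckpt_grad = block[0].weight.grad.clone()
```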

This implementation

This repo contains a PyTorch implementation that stays fast on multiple GPUs.

Main results

Method     # GPUs  Batch size  Memory  Time
Naive      2       256         5.25G   0.27s
Official   2       256         2.98G   1.41s
This repo  2       256         2.97G   0.31s

Naive = no checkpointing; Official = torch.utils.checkpoint.

Documentation

The main functionality is in checkpoint.py

import checkpoint
checkpoint.CheckpointFunction.apply(function, n, *args)

Parameters:

  • function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in an LSTM, if the user passes (activation, hidden), function should use the first input as activation and the second as hidden. Only the first n elements of args are passed to function.
  • n – the number of inputs to function.
  • args – a tuple containing the inputs to function AND the parameters to optimize inside function. The first n elements must be the inputs to function, in order; all remaining elements are treated as parameters.

Returns:

  • The output of running function on its n inputs.
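To make the (function, n, *args) calling convention concrete, here is a minimal sketch of an autograd.Function with the same signature. It is a hypothetical re-implementation (SketchCheckpoint, run, and lin are illustrative names), not the code in checkpoint.py: it assumes function returns a single tensor and it does not preserve RNG state, so it would not be correct for blocks containing dropout.

```python
import torch
import torch.nn as nn


class SketchCheckpoint(torch.autograd.Function):
    """Hypothetical sketch of the (function, n, *args) convention, not checkpoint.py.

    Simplifications: single-tensor output, no RNG-state handling.
    """

    @staticmethod
    def forward(ctx, function, n, *args):
        ctx.function, ctx.n = function, n
        # Save only the inputs and parameters. Autograd does not record operations
        # inside Function.forward, so activations created by `function` are not kept.
        ctx.save_for_backward(*args)
        return function(*args[:n])

    @staticmethod
    def backward(ctx, grad_output):
        args = ctx.saved_tensors
        # Detached copies of the inputs become the leaves of the recomputed graph.
        inputs = [a.detach().requires_grad_(a.requires_grad) for a in args[: ctx.n]]
        targets = inputs + list(args[ctx.n:])
        with torch.enable_grad():
            output = ctx.function(*inputs)  # recompute the forward pass
        needed = [t for t in targets if t.requires_grad]
        grads = iter(torch.autograd.grad(output, needed, grad_output, allow_unused=True))
        # One entry per forward argument: None for `function` and `n`, then one per arg.
        return (None, None) + tuple(next(grads) if t.requires_grad else None for t in targets)


# Usage: compare against an ordinary forward/backward pass.
torch.manual_seed(0)
lin = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)


def run(inp):
    return torch.relu(lin(inp))


run(x).sum().backward()
expected = [x.grad.clone(), lin.weight.grad.clone(), lin.bias.grad.clone()]
x.grad = None
lin.zero_grad()

out = SketchCheckpoint.apply(run, 1, x, *lin.parameters())
out.sum().backward()
actual = [x.grad, lin.weight.grad, lin.bias.grad]
```

In this sketch, listing the parameters in args is what lets backward return their gradients as ordinary outputs of the Function, even though function itself only receives the first n elements.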

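Note: When checkpointing, we recommend using cp_BatchNorm2d instead of torch.nn.BatchNorm2d. Because the checkpointed forward pass is run a second time during backpropagation, a standard BatchNorm2d in training mode would accumulate the same batch norm statistics twice.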
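The double update this note warns about can be observed with a plain torch.nn.BatchNorm2d under the official torch.utils.checkpoint. This sketch assumes a recent PyTorch with the use_reentrant keyword and uses the module's num_batches_tracked buffer as a convenient counter; it does not show how cp_BatchNorm2d avoids the problem, since that file's internals are not described here.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # training mode by default, so running stats are updated
x = torch.randn(4, 3, 5, 5, requires_grad=True)

out = checkpoint(bn, x, use_reentrant=False)
after_forward = int(bn.num_batches_tracked)

# Backward re-runs bn's forward to recompute activations, updating the stats again.
out.sum().backward()
after_backward = int(bn.num_batches_tracked)
```

The running mean and variance are updated on both passes as well, so the same batch contributes to them twice.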

DenseNet example

We provide an example of applying our checkpointing to the memory-efficient DenseNet. It only requires changing a few lines of the original implementation, which uses PyTorch's official checkpointing.

# Here cp is assumed to be this repo's checkpoint module (import checkpoint as cp).
# bn_function is a closure containing norm1, relu1 and conv1; prev_features is a tuple of tensors.
# Naive, no checkpointing:  bottleneck_output = bn_function(*prev_features)
# Official implementation:  bottleneck_output = torch.utils.checkpoint.checkpoint(bn_function, *prev_features)
# The parameters to optimize inside bn_function are those of norm1 and conv1.
args = prev_features + tuple(self.norm1.parameters()) + tuple(self.conv1.parameters())
bottleneck_output = cp.CheckpointFunction.apply(bn_function, len(prev_features), *args)
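The DenseNet snippet assumes bn_function and prev_features already exist. A self-contained sketch of how such a closure and the args tuple could be built, loosely modeled on a DenseNet bottleneck (concatenate previous feature maps, then norm, ReLU, 1x1 conv). make_bn_function and the channel sizes are hypothetical, and the call runs bn_function directly rather than through this repo's CheckpointFunction so that it needs only PyTorch.

```python
import torch
import torch.nn as nn


def make_bn_function(norm, relu, conv):
    # Hypothetical helper: concatenate all previous feature maps, then norm -> relu -> conv.
    def bn_function(*inputs):
        concatenated = torch.cat(inputs, 1)
        return conv(relu(norm(concatenated)))
    return bn_function


# Toy layer: two previous feature maps with 4 and 6 channels (10 after concatenation).
norm1 = nn.BatchNorm2d(10)
relu1 = nn.ReLU()
conv1 = nn.Conv2d(10, 8, kernel_size=1, bias=False)
bn_function = make_bn_function(norm1, relu1, conv1)

prev_features = (torch.randn(2, 4, 5, 5), torch.randn(2, 6, 5, 5))
args = prev_features + tuple(norm1.parameters()) + tuple(conv1.parameters())
n = len(prev_features)

# The first n entries are the inputs; the rest are the parameters used inside bn_function.
bottleneck_output = bn_function(*args[:n])
```

Here args holds the two feature maps followed by norm1's weight and bias and conv1's weight, which is the layout the n argument describes.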

Demo

python-fire is not required for checkpointing itself, but the efficient DenseNet demo needs it.

pip install fire
  • Our checkpointing demo:
    CUDA_VISIBLE_DEVICES=0,1 python cp_demo.py --efficient True --data cifar --save model --batch_size 256
  • The official implementation demo:
    CUDA_VISIBLE_DEVICES=0,1 python original_demo.py --efficient True --data cifar --save model --batch_size 256

Environment

This code was tested with PyTorch 1.0.0.dev20181102.

Speed was measured on a TITAN X (Pascal).

Full results

Method     # GPUs  Batch size  Memory  Time
Naive      1       256         9.93G   0.42s
Naive      2       4           0.65G   0.10s
Naive      2       256         5.25G   0.27s
Naive      2       512         9.93G   0.50s
Official   1       256         5.38G   0.52s
Official   1       512         10.1G   1.00s
Official   2       4           0.62G   1.40s
Official   2       256         2.98G   1.41s
Official   2       512         5.39G   1.53s
This repo  1       256         5.37G   0.50s
This repo  1       512         10.1G   0.97s
This repo  2       4           0.62G   0.13s
This repo  2       256         2.97G   0.31s
This repo  2       512         5.37G   0.58s
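A quick calculation with the 2-GPU, batch-size-256 rows of the full results table makes the trade-off explicit. The numbers are copied from the table; on a single GPU the official implementation and this repo take roughly the same time, so the gap shown here appears only in the multi-GPU rows.

```python
# Timings (seconds) from the full results table: 2 GPUs, batch size 256.
naive, official, this_repo = 0.27, 1.41, 0.31
# Memory (G) for the same configuration.
mem_naive, mem_this_repo = 5.25, 2.97

official_slowdown = official / naive          # about 5.2x slower than no checkpointing
this_repo_slowdown = this_repo / naive        # about 1.15x slower than no checkpointing
memory_fraction = mem_this_repo / mem_naive   # about 0.57 of the naive memory
```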

Credits

Part of our code in checkpoint.py and cp_BatchNorm2d.py is from https://github.com/pytorch/pytorch

The efficient densenet demo is taken from https://github.com/gpleiss/efficient_densenet_pytorch
