Gradient checkpointing (#78)
* Gradient checkpointing: part 1

* Gradient checkpointing: part 2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Patch (wrong name)

* Appease pre-commit.ci

* Fix bug

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
neverix and pre-commit-ci[bot] authored Dec 17, 2021
1 parent 8fd2147 commit a6e01de
Showing 2 changed files with 40 additions and 3 deletions.
5 changes: 4 additions & 1 deletion rudalle/dalle/model.py
@@ -96,6 +96,7 @@ def forward(
             return_loss=False,
             has_cache=False,
             use_cache=False,
+            gradient_checkpointing=None,
     ):
         text = input_ids[:, :self.text_seq_length]
         text_range = torch.arange(self.text_seq_length)
@@ -123,7 +124,9 @@ def forward(
 
         attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
         transformer_output, present_has_cache = self.transformer(
-            embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
+            embeddings, attention_mask,
+            has_cache=has_cache, use_cache=use_cache,
+            gradient_checkpointing=gradient_checkpointing)
 
         logits = self.to_logits(transformer_output)
         if return_loss is False:
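Callers opt in per forward pass. A minimal training-style sketch follows; the model object and input tensors are assumed, not part of this diff, and only the new keyword and the use_cache constraint come from the changes above:

# input_ids: [batch, text_seq_length + image_seq_length] token ids (shape assumed)
# attention_mask: the 4-D mask the model already expects
outputs = model(input_ids, attention_mask,
                return_loss=True,
                use_cache=False,             # the transformer asserts not use_cache when checkpointing
                gradient_checkpointing=6)    # forwarded to checkpoint_sequential as the segment count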
38 changes: 36 additions & 2 deletions rudalle/dalle/transformer.py
@@ -18,6 +18,25 @@ def gelu_jit(x):
     return gelu(x)
 
 
+class Layer(torch.nn.Module):
+    """
+    Helper class for gradient checkpointing.
+    """
+
+    def __init__(self, x, f, *args, **kwargs):
+        super(Layer, self).__init__()
+        # module to checkpoint
+        self.x = x
+        # post-processing function
+        self.f = f
+        # arguments to the module
+        self.args = args
+        self.kwargs = kwargs
+
+    def forward(self, x):
+        return self.f(self.x(x, *self.args, **self.kwargs))
+
+
 def rescale_max(h, scale=False):
     if scale:
         # This transformation does not affect following layernorm output.
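torch.utils.checkpoint.checkpoint_sequential expects a sequence of modules that each take and return a single tensor, while the transformer blocks here take extra arguments and return a (hidden_states, present_has_cache) tuple. Layer closes over those extras plus a post-processing function so each wrapped block becomes tensor-in, tensor-out. A standalone sketch of the pattern, with an illustrative Block stand-in (only Layer and the checkpointing call mirror this commit):

import torch
from torch.utils.checkpoint import checkpoint_sequential

from rudalle.dalle.transformer import Layer  # the helper added in this commit


class Block(torch.nn.Module):
    """Stand-in for a transformer layer: extra args, tuple output."""

    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x, mask, has_cache=False, use_cache=False):
        return self.proj(x) * mask, False  # (hidden_states, present_has_cache)


dim, depth, segments = 16, 8, 4
mask = torch.ones(dim)

# Each Layer calls block(x, mask, has_cache=False, use_cache=False) and keeps
# only element 0 of the result, so checkpoint_sequential sees tensor -> tensor.
wrapped = [Layer(Block(dim), lambda out: out[0], mask,
                 has_cache=False, use_cache=False) for _ in range(depth)]

x = torch.randn(2, dim, requires_grad=True)
hidden = checkpoint_sequential(wrapped, segments, x)  # recomputes per-segment activations on backward
hidden.sum().backward()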
@@ -110,12 +129,27 @@ def _get_layer_mask(self, layer_id):
             layer_mask = self.conv_mask
         return layer_mask
 
-    def forward(self, hidden_states, attention_mask, has_cache, use_cache):
+    def forward(self, hidden_states, attention_mask, has_cache, use_cache, gradient_checkpointing=None):
+        if gradient_checkpointing:
+            assert not use_cache
+            layers = []
         for i, layer in enumerate(self.layers):
             mask = attention_mask
             layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
             mask = torch.mul(attention_mask, layer_mask)
-            hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
+            if gradient_checkpointing:
+                layers.append(Layer(layer,
+                                    # only get the embeddings, not present_has_cache
+                                    lambda x: x[0],
+                                    mask,
+                                    use_cache=False, has_cache=False))
+            else:
+                hidden_states, present_has_cache = layer(
+                    hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
+        if gradient_checkpointing:
+            hidden_states = torch.utils.checkpoint.checkpoint_sequential(
+                layers, gradient_checkpointing, hidden_states)
+            present_has_cache = False
         hidden_states = rescale_max(hidden_states, self.custom_relax)
         output = self.final_layernorm(hidden_states)
         return output, present_has_cache
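Note that gradient_checkpointing is not a boolean here: it is passed straight through as the segments argument of checkpoint_sequential, i.e. how many chunks the layer stack is split into, with only chunk-boundary activations kept and each chunk re-run during backward. A hedged call at the transformer level (tensor shapes are assumptions; the signature and return values are the ones shown above):

# hidden_states: [batch, seq, hidden]; attention_mask: [batch, 1, seq, seq] (shapes assumed)
output, present_has_cache = transformer(
    hidden_states, attention_mask,
    has_cache=False, use_cache=False,   # caching is asserted off on the checkpointed path
    gradient_checkpointing=4)           # split self.layers into 4 checkpointed segments
# present_has_cache comes back False whenever checkpointing is active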
