Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradient checkpointing #78

Merged
merged 7 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion rudalle/dalle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def forward(
return_loss=False,
has_cache=False,
use_cache=False,
gradient_checkpointing=None,
):
text = input_ids[:, :self.text_seq_length]
text_range = torch.arange(self.text_seq_length)
Expand Down Expand Up @@ -123,7 +124,9 @@ def forward(

attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
transformer_output, present_has_cache = self.transformer(
embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
embeddings, attention_mask,
has_cache=has_cache, use_cache=use_cache,
gradient_checkpointing=gradient_checkpointing)

logits = self.to_logits(transformer_output)
if return_loss is False:
Expand Down
38 changes: 36 additions & 2 deletions rudalle/dalle/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ def gelu_jit(x):
return gelu(x)


class Layer(torch.nn.Module):
"""
Helper class for gradient checkpointing.
"""

def __init__(self, x, f, *args, **kwargs):
super(Layer, self).__init__()
# module to checkpoint
self.x = x
# post-processing function
self.f = f
# arguments to the module
self.args = args
self.kwargs = kwargs

def forward(self, x):
return self.f(self.x(x, *self.args, **self.kwargs))


def rescale_max(h, scale=False):
if scale:
# This transformation does not affect following layernorm output.
Expand Down Expand Up @@ -110,12 +129,27 @@ def _get_layer_mask(self, layer_id):
layer_mask = self.conv_mask
return layer_mask

def forward(self, hidden_states, attention_mask, has_cache, use_cache):
def forward(self, hidden_states, attention_mask, has_cache, use_cache, gradient_checkpointing=None):
if gradient_checkpointing:
assert not use_cache
layers = []
for i, layer in enumerate(self.layers):
mask = attention_mask
layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
mask = torch.mul(attention_mask, layer_mask)
hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
if gradient_checkpointing:
layers.append(Layer(layer,
# only get the embeddings, not present_has_cache
lambda x: x[0],
mask,
use_cache=False, has_cache=False))
else:
hidden_states, present_has_cache = layer(
hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
if gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint_sequential(
layers, gradient_checkpointing, hidden_states)
present_has_cache = False
hidden_states = rescale_max(hidden_states, self.custom_relax)
output = self.final_layernorm(hidden_states)
return output, present_has_cache
Expand Down