From a6e01de413a81987c22386645ea5a313c386cac7 Mon Sep 17 00:00:00 2001
From: neverix
Date: Fri, 17 Dec 2021 14:36:33 +0300
Subject: [PATCH] Gradient checkpointing (#78)

* Gradient checkpointing: part 1

* Gradient checkpointing: part 2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Patch (wrong name)

* Appease pre-commit.ci

* Fix bug

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 rudalle/dalle/model.py       |  5 ++++-
 rudalle/dalle/transformer.py | 38 ++++++++++++++++++++++++++++++++++--
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/rudalle/dalle/model.py b/rudalle/dalle/model.py
index 28379cb..1b2b0ec 100644
--- a/rudalle/dalle/model.py
+++ b/rudalle/dalle/model.py
@@ -96,6 +96,7 @@ def forward(
         return_loss=False,
         has_cache=False,
         use_cache=False,
+        gradient_checkpointing=None,
     ):
         text = input_ids[:, :self.text_seq_length]
         text_range = torch.arange(self.text_seq_length)
@@ -123,7 +124,9 @@ def forward(
         attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
 
         transformer_output, present_has_cache = self.transformer(
-            embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)
+            embeddings, attention_mask,
+            has_cache=has_cache, use_cache=use_cache,
+            gradient_checkpointing=gradient_checkpointing)
 
         logits = self.to_logits(transformer_output)
         if return_loss is False:
diff --git a/rudalle/dalle/transformer.py b/rudalle/dalle/transformer.py
index 317662a..83bb356 100755
--- a/rudalle/dalle/transformer.py
+++ b/rudalle/dalle/transformer.py
@@ -18,6 +18,25 @@ def gelu_jit(x):
     return gelu(x)
 
 
+class Layer(torch.nn.Module):
+    """
+    Helper class for gradient checkpointing.
+    """
+
+    def __init__(self, x, f, *args, **kwargs):
+        super(Layer, self).__init__()
+        # module to checkpoint
+        self.x = x
+        # post-processing function
+        self.f = f
+        # arguments to the module
+        self.args = args
+        self.kwargs = kwargs
+
+    def forward(self, x):
+        return self.f(self.x(x, *self.args, **self.kwargs))
+
+
 def rescale_max(h, scale=False):
     if scale:
         # This transformation does not affect following layernorm output.
@@ -110,12 +129,27 @@ def _get_layer_mask(self, layer_id):
             layer_mask = self.conv_mask
         return layer_mask
 
-    def forward(self, hidden_states, attention_mask, has_cache, use_cache):
+    def forward(self, hidden_states, attention_mask, has_cache, use_cache, gradient_checkpointing=None):
+        if gradient_checkpointing:
+            assert not use_cache
+        layers = []
         for i, layer in enumerate(self.layers):
             mask = attention_mask
             layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
             mask = torch.mul(attention_mask, layer_mask)
-            hidden_states, present_has_cache = layer(hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
+            if gradient_checkpointing:
+                layers.append(Layer(layer,
+                                    # only get the embeddings, not present_has_cache
+                                    lambda x: x[0],
+                                    mask,
+                                    use_cache=False, has_cache=False))
+            else:
+                hidden_states, present_has_cache = layer(
+                    hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
+        if gradient_checkpointing:
+            hidden_states = torch.utils.checkpoint.checkpoint_sequential(
+                layers, gradient_checkpointing, hidden_states)
+            present_has_cache = False
         hidden_states = rescale_max(hidden_states, self.custom_relax)
         output = self.final_layernorm(hidden_states)
         return output, present_has_cache
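
Usage note (not part of the patch): `gradient_checkpointing` here is the number of segments handed to `torch.utils.checkpoint.checkpoint_sequential`, so a training step would pass something like `gradient_checkpointing=6` to the model's forward while keeping `use_cache=False`. Below is a minimal, self-contained sketch of the pattern the diff relies on: wrap each layer in the single-input `Layer` helper, then let `checkpoint_sequential` recompute intermediate activations during backward instead of storing them. `ToyBlock`, the tensor shapes, and the segment count are illustrative stand-ins, not part of the rudalle API.

import torch
from torch.utils.checkpoint import checkpoint_sequential


class ToyBlock(torch.nn.Module):
    """Stand-in for a transformer layer: takes an extra mask, returns a tuple."""

    def __init__(self, dim):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x, mask, has_cache=False, use_cache=False):
        # (hidden_states, present_has_cache), like the layers in the diff
        return torch.relu(self.linear(x)) * mask, False


class Layer(torch.nn.Module):
    """Same shape as the helper added in the diff: bind extra args, post-process the output."""

    def __init__(self, x, f, *args, **kwargs):
        super().__init__()
        self.x = x            # module to checkpoint
        self.f = f            # post-processing function
        self.args = args      # extra positional args for the module
        self.kwargs = kwargs  # extra keyword args for the module

    def forward(self, x):
        return self.f(self.x(x, *self.args, **self.kwargs))


mask = torch.ones(4, 16)
# Bind the mask and keep only element 0 of the returned tuple, as the diff does.
layers = [Layer(ToyBlock(16), lambda out: out[0], mask, use_cache=False, has_cache=False)
          for _ in range(8)]

x = torch.randn(4, 16, requires_grad=True)
# Split the 8 wrapped layers into 2 segments: only segment-boundary activations
# are kept; everything in between is recomputed during the backward pass.
out = checkpoint_sequential(layers, 2, x)
out.sum().backward()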