Gradient Accumulation in PyTorch, GANs and other Nightmares
You are using PyTorch and would like to use gradient accumulation? Preferably within a GAN? Then you may already have run into the pitfalls summarized in this document. If you are a PyTorch expert, feel free to skip this page.
The standard procedure to run one training step for a single model is as follows:
```python
import torch
# And import all other libraries

# Define a model and optimizer:
model = ...
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Load data, define training loop, etc.
# ...

# Run a single training step:
# 1. Reset all model gradients:
optimizer.zero_grad(set_to_none=True)
# 2. Get the model predictions:
y_pred = model(x)
# 3. Formulate the loss:
loss = torch.nn.MSELoss()(y_pred, y_true)
# 4. Run backpropagation:
loss.backward()
# 5. Update the model weights / biases with a single optimization step:
optimizer.step()
# Monitor everything and evaluate performance...
```
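So far so good. The first item in the above code snippet is important, because it resets all gradients that the model may still carry from a previous step (with set_to_none=True, as used above, they are set to None; with set_to_none=False they are overwritten with zeros).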
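The difference between the two reset modes can be checked on a toy model. The sketch below assumes a one-layer nn.Linear model and random inputs, both purely illustrative; it shows that set_to_none=False keeps the gradient tensors but fills them with zeros, whereas set_to_none=True removes them altogether.

```python
import torch

# Toy model and random inputs (illustrative assumptions, not part of the original text)
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 3)

# Populate the gradients with one backward pass
model(x).pow(2).mean().backward()

# set_to_none=False keeps the gradient tensors but fills them with zeros
optimizer.zero_grad(set_to_none=False)
all_zero = all(bool((p.grad == 0).all()) for p in model.parameters())

# Populate the gradients again
model(x).pow(2).mean().backward()

# set_to_none=True (the default since PyTorch 2.0) removes the gradient tensors altogether
optimizer.zero_grad(set_to_none=True)
all_none = all(p.grad is None for p in model.parameters())

print(all_zero, all_none)  # True True
```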
Item 4 (backpropagation) uses PyTorch's autograd under the hood. autograd is great, but it sees everything! (Like Sauron from The Lord of the Rings, but without the scary eye vibe...) To quote the autograd manual: "[...] This is where autograd comes in: It tracks the history of every computation. Every computed tensor in your PyTorch model carries a history of its input tensors and the function used to create it [...]". This says it all...
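This recorded history can be inspected directly through the grad_fn attribute. The small sketch below uses a hand-made tensor w as a stand-in for a trainable parameter (an illustrative assumption): w itself has no history, while the tensor computed from it points back to the operation that produced it.

```python
import torch

w = torch.tensor([2.0, -1.0], requires_grad=True)  # a trainable leaf tensor
x = torch.tensor([3.0, 4.0])                        # plain input data

y = (w * x).sum()  # autograd records how y was computed from w

print(w.grad_fn)   # None: user-created leaf tensors have no history
print(y.grad_fn)   # a SumBackward0 node, the function that produced y
print(y.grad_fn.next_functions)  # links back to the multiplication w * x

y.backward()
print(w.grad)      # tensor([3., 4.]), since dy/dw = x
```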
Once torch.autograd.backward() (called here via loss.backward()) has run, every trainable parameter that contributed to the loss has a gradient stored in its .grad attribute. To make this concrete:
```python
# Model freshly initialized:
for p in model.parameters():
    print(p.grad)  # --> None, no backward pass has run yet

# Run backpropagation:
loss.backward()
for p in model.parameters():
    print(p.grad)  # --> The actual gradient tensor
```
If one ran another forward and backward pass without calling optimizer.zero_grad() in between, the new gradients would be added to the ones already stored in .grad, because PyTorch accumulates (sums) gradients by default. Thus, if you want to be sure that your model receives the proper gradients, you have to call optimizer.zero_grad() between two optimization steps.
We learned that the canonical code to run one optimization step is:
```python
# Reset
optimizer.zero_grad()
# Loss computation...
loss.backward()
# Update parameters:
optimizer.step()
```
But, following the logic we explored so far, the following code yields the same results:
```python
# Loss computation...
loss.backward()
# Update parameters:
optimizer.step()
# Reset
optimizer.zero_grad()
```
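This works because the gradients are reset right after the optimization step, so the model starts the next training iteration without any gradient history. Since zero_grad() only has to happen somewhere between two optimizer steps, and backward() simply adds to the stored gradients, gradient accumulation is straightforward in PyTorch: call backward() several times, and only then call step() and zero_grad().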
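A minimal check of this claim, assuming a toy nn.Linear model, Adam and random data: two identical copies are trained for ten steps, one with zero_grad() before backward() and one with zero_grad() after step(), and their parameters are compared at the end.

```python
import copy
import torch

torch.manual_seed(0)
# Two identical toy models (illustrative assumptions)
model_a = torch.nn.Linear(4, 1)
model_b = copy.deepcopy(model_a)
opt_a = torch.optim.Adam(model_a.parameters(), lr=1e-2)
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()

for _ in range(10):
    x, y_true = torch.randn(16, 4), torch.randn(16, 1)

    # Variant A: reset first, then backward + step
    opt_a.zero_grad()
    loss_fn(model_a(x), y_true).backward()
    opt_a.step()

    # Variant B: backward + step, then reset
    loss_fn(model_b(x), y_true).backward()
    opt_b.step()
    opt_b.zero_grad()

same = all(torch.allclose(pa, pb) for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```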
The following code snippet shows how gradient accumulation is done in PyTorch:
```python
# Run the gradient accumulation:
n_accumulation_steps = 5
for _ in range(n_accumulation_steps):
    # Draw random samples x (and the matching targets y_true) from your data:
    x = ...
    # Get model response:
    y_pred = model(x)
    # Compute the loss...
    loss = torch.nn.MSELoss()(y_pred, y_true)
    loss = loss / n_accumulation_steps  # --> This ensures that the gradients are properly normalized
    # ...and backpropagate; the gradients are summed into each parameter's .grad:
    loss.backward()
# And now run the optimizer step to use all accumulated gradients:
optimizer.step()
optimizer.zero_grad()
```
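The normalization can be checked against a plain large-batch step. The sketch below assumes a toy nn.Linear model, plain SGD, MSELoss with its default mean reduction and equally sized micro-batches. Under these assumptions, accumulating loss / n_accumulation_steps over five micro-batches of four samples gives the same update as a single step on all twenty samples.

```python
import copy
import torch

torch.manual_seed(0)
# Toy model, optimizer and data (illustrative assumptions)
model_acc = torch.nn.Linear(4, 1)
model_full = copy.deepcopy(model_acc)
opt_acc = torch.optim.SGD(model_acc.parameters(), lr=0.1)
opt_full = torch.optim.SGD(model_full.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()  # mean reduction

x, y_true = torch.randn(20, 4), torch.randn(20, 1)
n_accumulation_steps = 5

# Gradient accumulation: 5 micro-batches of 4 samples each
for x_mb, y_mb in zip(x.chunk(n_accumulation_steps), y_true.chunk(n_accumulation_steps)):
    loss = loss_fn(model_acc(x_mb), y_mb) / n_accumulation_steps
    loss.backward()  # gradients are summed into .grad
opt_acc.step()
opt_acc.zero_grad()

# Reference: one step on the full batch of 20 samples
loss_fn(model_full(x), y_true).backward()
opt_full.step()
opt_full.zero_grad()

match = all(torch.allclose(a, b, atol=1e-6)
            for a, b in zip(model_acc.parameters(), model_full.parameters()))
print(match)  # True
```

The equivalence relies on equal micro-batch sizes and a mean-reduced loss; layers that depend on the whole batch, such as BatchNorm, break it.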
Why is the loss normalized in the accumulation loop above? Because backward() sums the gradients of all accumulation steps, dividing each loss by n_accumulation_steps turns that sum into an average, which keeps the update on the same scale as a single optimization step. There are situations where this normalization is not needed; in these cases, one may leave it out. An alternative implementation of the above code snippet could be:
```python
# Run the gradient accumulation:
n_accumulation_steps = 5
for _ in range(n_accumulation_steps):
    # Draw random samples x from your data:
    x = ...
    # Get model response:
    y_pred = model(x)
    # Compute the loss (not normalized this time) and backpropagate:
    loss = torch.nn.MSELoss()(y_pred, y_true)
    loss.backward()
# The loss has not been normalized, but one prefers to average the gradients nevertheless.
# Do this once, after the accumulation loop:
for p in model.parameters():
    if p.grad is not None:
        p.grad = p.grad / n_accumulation_steps  # --> Overwrite each gradient with its normalized version
# And now run the optimizer step to use all accumulated (and normalized) gradients:
optimizer.step()
optimizer.zero_grad()
```
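The two normalization variants can be compared directly. The sketch below assumes a toy nn.Linear model and five random mini-batches; it accumulates the same batches once with a normalized loss and once with gradients normalized afterwards, and checks that the resulting .grad tensors agree up to floating-point rounding.

```python
import copy
import torch

torch.manual_seed(0)
# Toy model and random mini-batches (illustrative assumptions)
model_a = torch.nn.Linear(4, 1)
model_b = copy.deepcopy(model_a)
loss_fn = torch.nn.MSELoss()
n_accumulation_steps = 5
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(n_accumulation_steps)]

# Approach 1: divide each loss before backward()
for x, y_true in batches:
    (loss_fn(model_a(x), y_true) / n_accumulation_steps).backward()

# Approach 2: accumulate the raw losses, then divide the summed gradients once
for x, y_true in batches:
    loss_fn(model_b(x), y_true).backward()
for p in model_b.parameters():
    p.grad /= n_accumulation_steps

same = all(torch.allclose(pa.grad, pb.grad, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```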
From a conceptual point of view, both approaches yield the same results. However, the first implementation is a bit more elegant and efficient, as it does not require an additional for-loop over all model parameters. There are many resources online that nicely explain gradient accumulation in PyTorch. I will just list the three that helped me personally the most:
The PyTorch implementation for a GAN with a generator G and a discriminator D may look like this:
```python
import torch

# Define both models G and D somewhere, including the corresponding optimizers
# (D is assumed to end in a sigmoid, since BCELoss expects probabilities):
G = ...
D = ...
g_opt = torch.optim.Adam(G.parameters(), lr=1e-5)
d_opt = torch.optim.Adam(D.parameters(), lr=1e-4)

# A) Train the generator first:
# Create a noise tensor:
noise = torch.normal(mean=0.0, std=1.0, size=(1000, 20))
# Produce fake data:
x_fake = G(noise)
# And get the discriminator response:
y_fake = D(x_fake)
# Reset optimizer:
g_opt.zero_grad(set_to_none=True)
# Compute loss:
g_loss = torch.nn.BCELoss()(y_fake, torch.ones_like(y_fake))
# Run backpropagation:
g_loss.backward()
# And update the weights:
g_opt.step()

# B) Train the discriminator:
# Get the discriminator response on a batch of real data x_real:
y_real = D(x_real)
# Re-evaluate D on the detached fake data: g_opt.step() has just modified G's weights in place,
# so backpropagating through the old graph would typically raise an in-place modification error,
# and the discriminator loss must not update G anyway:
y_fake = D(x_fake.detach())
# Do not forget to reset the optimizer:
d_opt.zero_grad(set_to_none=True)
d_loss = torch.nn.BCELoss()(y_fake, torch.zeros_like(y_fake)) + torch.nn.BCELoss()(y_real, torch.ones_like(y_real))
d_loss.backward()
# And update the discriminator weights:
d_opt.step()
# rinse and repeat....
```
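The above code also works if we swap the order and train the discriminator first and then the generator. In both cases, d_opt.zero_grad() is called right before d_loss.backward(), so any gradients that the generator's backward pass left in D's parameters are discarded before the discriminator update.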
Now the following code, which only moves each zero_grad() call after the corresponding step(), will break the GAN training:
```python
# A) Train the generator first:
# Create a noise tensor:
noise = torch.normal(mean=0.0, std=1.0, size=(1000, 20))
# Produce fake data:
x_fake = G(noise)
# And get the discriminator response:
y_fake = D(x_fake)
# Compute loss:
g_loss = torch.nn.BCELoss()(y_fake, torch.ones_like(y_fake))
# Run backpropagation:
g_loss.backward()
# And update the weights:
g_opt.step()
# Reset optimizer:
g_opt.zero_grad(set_to_none=True)

# B) Train the discriminator:
# Get the discriminator response on real data:
y_real = D(x_real)
# Discriminator response on the detached fake data:
y_fake = D(x_fake.detach())
d_loss = torch.nn.BCELoss()(y_fake, torch.zeros_like(y_fake)) + torch.nn.BCELoss()(y_real, torch.ones_like(y_real))
d_loss.backward()
# And update the discriminator weights:
d_opt.step()
# Do not forget to reset the optimizer:
d_opt.zero_grad(set_to_none=True)
# rinse and repeat....
```
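The damage can be made visible with a small numerical check. The sketch below assumes hypothetical toy models (a linear G and a linear-plus-sigmoid D) and random stand-ins for real data. It computes the discriminator gradient once on its own and once after a generator backward pass that is never cleared from D, and shows that the second result is the clean gradient plus the leftover generator term.

```python
import torch

torch.manual_seed(0)
# Hypothetical toy models: G maps 2-D noise to 3-D samples, D outputs a probability
G = torch.nn.Linear(2, 3)
D = torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Sigmoid())
bce = torch.nn.BCELoss()
noise, x_real = torch.randn(16, 2), torch.randn(16, 3)  # random stand-ins for real data


def d_loss_fn(x_fake):
    # Discriminator loss on detached fake data plus real data
    y_fake, y_real = D(x_fake.detach()), D(x_real)
    return bce(y_fake, torch.zeros_like(y_fake)) + bce(y_real, torch.ones_like(y_real))


x_fake = G(noise)

# Clean discriminator gradient: nothing else has written into D's .grad
d_loss_fn(x_fake).backward()
clean_grad = D[0].weight.grad.clone()
D.zero_grad(set_to_none=True)

# Broken ordering: the generator's backward pass also writes into D ...
y_fake = D(x_fake)
bce(y_fake, torch.ones_like(y_fake)).backward()
leftover_grad = D[0].weight.grad.clone()
# ... and without a reset in between, the discriminator's backward pass adds on top
d_loss_fn(x_fake).backward()
contaminated_grad = D[0].weight.grad.clone()

print(torch.allclose(contaminated_grad, clean_grad + leftover_grad))  # True
print(leftover_grad.abs().sum())  # the extra term is generally non-zero
```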
One will notice that the training does not converge properly. Why is that? The reason is autograd, because it sees all gradients. The generator's backward pass (g_loss.backward) runs the loss backwards through the generator AND the discriminator, because the discriminator sits between the generator output and the generator loss. This means that autograd attaches a gradient to every trainable parameter within the discriminator network, so the discriminator already holds gradient information for each parameter. Since d_opt.zero_grad() is now only called after d_opt.step(), nothing clears these gradients, and d_loss.backward() adds the gradients that we are actually interested in on top of those from the generator backpropagation. What a disaster! The above code snippet can be used if we make the following changes to it:
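```python
# Tell autograd that we do not want to keep track of the discriminator gradients:
for param in D.parameters():
    param.requires_grad = False
# Compute the generator loss and update the generator:
x_fake = G(noise)
y_fake = D(x_fake)
g_loss = torch.nn.BCELoss()(y_fake, torch.ones_like(y_fake))
g_loss.backward()  # gradients still flow *through* D into G, but none are stored in D
g_opt.step()
g_opt.zero_grad(set_to_none=True)
# Now re-register the discriminator gradients again:
for param in D.parameters():
    param.requires_grad = True
# Compute the discriminator loss and update the discriminator:
y_real = D(x_real)
y_fake = D(x_fake.detach())
d_loss = torch.nn.BCELoss()(y_fake, torch.zeros_like(y_fake)) + torch.nn.BCELoss()(y_real, torch.ones_like(y_real))
d_loss.backward()
d_opt.step()
d_opt.zero_grad(set_to_none=True)
```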
Now there is no apparent reason for coding the GAN training this way, because we have to add more lines and handle two additional for-loops. But the code above works as expected and allows the GAN to converge.
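That freezing D really keeps its gradients clean can be checked in isolation. The sketch below assumes the same kind of hypothetical toy G and D as before: with requires_grad=False on D's parameters, the generator's backward pass still reaches G, but leaves D's .grad untouched.

```python
import torch

torch.manual_seed(0)
G = torch.nn.Linear(2, 3)                                            # hypothetical toy generator
D = torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Sigmoid())  # hypothetical toy discriminator
bce = torch.nn.BCELoss()

# Freeze D: autograd still propagates *through* D, but stores no gradients in it
for param in D.parameters():
    param.requires_grad = False

y_fake = D(G(torch.randn(16, 2)))
bce(y_fake, torch.ones_like(y_fake)).backward()

d_grads_are_none = all(p.grad is None for p in D.parameters())
g_grads_exist = all(p.grad is not None for p in G.parameters())
print(d_grads_are_none, g_grads_exist)  # True True

# Unfreeze D before its own update
for param in D.parameters():
    param.requires_grad = True
```

nn.Module also offers requires_grad_(False), which sets the flag on all parameters in one call and is equivalent to the loops above.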
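On its own, this ordering offers no clear advantage. Its benefit only appears once gradient accumulation is added: there, zero_grad() may only be called after step(), so the discriminator's gradients cannot be wiped between accumulation steps, and freezing D during the generator's backward pass is what keeps them clean.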
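Putting the pieces together, a GAN training loop with gradient accumulation could look roughly like the sketch below. It is not the author's implementation: the toy architectures, the random "real" data from the hypothetical helper sample_real, the batch size of 32 and the choice of three updates are all illustrative assumptions. Each micro-step freezes D for the generator's backward pass and detaches x_fake for the discriminator's, so both .grad buffers accumulate only their own losses until the shared step().

```python
import torch

torch.manual_seed(0)
# Hypothetical toy GAN: 2-D noise -> 3-D samples
G = torch.nn.Sequential(torch.nn.Linear(2, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3))
D = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.ReLU(),
                        torch.nn.Linear(8, 1), torch.nn.Sigmoid())
g_opt = torch.optim.Adam(G.parameters(), lr=1e-5)
d_opt = torch.optim.Adam(D.parameters(), lr=1e-4)
bce = torch.nn.BCELoss()
n_accumulation_steps = 4


def sample_real(batch_size):
    # Hypothetical placeholder for a real data loader
    return torch.randn(batch_size, 3)


g_before = [p.detach().clone() for p in G.parameters()]
d_before = [p.detach().clone() for p in D.parameters()]

for _ in range(3):  # three optimizer updates
    for _ in range(n_accumulation_steps):
        x_fake = G(torch.randn(32, 2))

        # Generator micro-step: freeze D so nothing is written into D's .grad
        D.requires_grad_(False)
        y_fake = D(x_fake)
        g_loss = bce(y_fake, torch.ones_like(y_fake)) / n_accumulation_steps
        g_loss.backward()
        D.requires_grad_(True)

        # Discriminator micro-step: detach x_fake so nothing is written into G's .grad
        y_fake_d = D(x_fake.detach())
        y_real = D(sample_real(32))
        d_loss = (bce(y_fake_d, torch.zeros_like(y_fake_d))
                  + bce(y_real, torch.ones_like(y_real))) / n_accumulation_steps
        d_loss.backward()

    # One update per model with the accumulated gradients, then reset
    g_opt.step()
    d_opt.step()
    g_opt.zero_grad(set_to_none=True)
    d_opt.zero_grad(set_to_none=True)

g_changed = any(not torch.equal(a, b) for a, b in zip(g_before, G.parameters()))
d_changed = any(not torch.equal(a, b) for a, b in zip(d_before, D.parameters()))
print(g_changed, d_changed)
```

Calling d_opt.zero_grad() inside the inner loop would discard the discriminator's accumulated gradients, which is why freezing D is the tool of choice here.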