Thanks for the great library!

I'm learning to use the library by writing a DEQ layer inside GPT-2. After loss.backward(), some model parameters do not get gradients. I tracked the issue down: the fixed-point solution z_out[-1] returned by the DEQ layer (a tensor of shape (batch size, sequence length, hidden dimension)) has requires_grad set to False. Strangely, if I use a freshly initialized GPT-2 model instead (without the pretrained weights), the issue is gone.
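For context, self.deq in the snippets below is a torchdeq solver module. Roughly, I construct and call it along the lines of the sketch below; the get_deq keyword names are taken from my reading of the torchdeq docs and are assumptions here rather than my exact configuration:

```python
import torch
from torchdeq import get_deq

# Assumed solver construction (the kwarg names may differ from my real setup).
deq = get_deq(f_solver='fixed_point_iter', f_max_iter=40, f_tol=1e-4)

# Toy stand-ins for the GPT-2 pieces, just to show the call convention.
block = torch.nn.Linear(8, 8)                 # stands in for the transformer blocks
x = torch.randn(2, 8)                         # stands in for the embedded input
func = lambda z: torch.tanh(block(z) + x)     # one DEQ iteration: z_{k+1} = f(z_k; x)

z0 = torch.zeros_like(x)
z_out, info = deq(func, z0)                   # z_out is a list of fixed-point estimates
z_out[-1].sum().backward()
print(block.weight.grad is None)              # expect False when the solution is attached to the graph
```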
If I print the requires_grad of the solution in each iteration, before and after the implicit function, like this:
```python
def layer_iter(hidden_states, input_hidden_states):
    print("before:", hidden_states.requires_grad)
    # forward pass through each transformer block, where input_hidden_states
    # is the input embedding
    # ...
    print("after:", hidden_states.requires_grad)
    print("----")
    return hidden_states

func = lambda var: layer_iter(var, hidden_states_)  # hidden_states_ is the input after the embedding layer
zeros_ = torch.zeros(*hidden_states_.shape, requires_grad=True).to(hidden_states_.device)
z_out, info = self.deq(func, zeros_)
print("z_out[-1].requires_grad:", z_out[-1].requires_grad)
print(info)
hidden_states = z_out[-1]
```
From the printed output, the model does seem to converge in 9 iterations, but the returned hidden states have requires_grad=False, so the model's parameters (besides those after the DEQ layer) do not get gradients. I tried manually setting z_out[-1].requires_grad = True, but this doesn't help; after loss.backward() the .grad is still None for those parameters.
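(To illustrate why flipping the flag by hand can't help, here is a small standalone PyTorch snippet, unrelated to the DEQ solver itself: once a tensor has been produced without a grad_fn, setting requires_grad=True only turns it into a new leaf, and backward() never reaches the parameters that produced it.)

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with torch.no_grad():                 # stands in for whatever detached the fixed point
    z = lin(x)

print(z.requires_grad, z.grad_fn)     # False None

z.requires_grad = True                # makes z a new leaf; it does NOT reattach lin
z.sum().backward()
print(lin.weight.grad)                # still None: the graph behind z was never recorded
```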
Intriguingly, if I use a freshly initialized GPT-2, the issue seems to go away: the same prints now show requires_grad=True. Also, apart from requires_grad, the sradius in info is now -1 instead of 0; I'm not sure whether that is related.
I wonder if you have ideas about why this happens. Would appreciate it!