Description
Thanks for the great library!
I'm learning to use the library by writing a DEQ layer inside GPT-2. After loss.backward(), some model parameters do not get gradients. I tracked the issue down: the fixed-point solution z_out[-1] returned by the DEQ layer (a tensor of shape (batch size, sequence length, hidden dimension)) has requires_grad set to False. Strangely, if I use a freshly initialized GPT-2 model instead (without the pretrained weights), the issue is gone.
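For reference, roughly how the missing gradients show up after the backward pass (a simplified sketch; loss stands in for my actual training loss):

loss.backward()
# any trainable parameter whose .grad is still None did not receive a gradient
missing = [name for name, p in model.named_parameters()
           if p.requires_grad and p.grad is None]
print(len(missing), "parameters without gradients, e.g.", missing[:5])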
Specifically, this is my deq setup:
self.deq = get_deq(
    ift=True,
    f_solver='broyden', f_max_iter=15, f_tol=1e-3, f_stop_mode='rel',
    b_solver='broyden', b_max_iter=15, b_tol=1e-3, b_stop_mode='rel',
)
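As a side note, I would expect the same settings to give a differentiable output on a small toy function; here is a standalone sanity-check sketch (not part of the GPT-2 repro, and the expected values in the comments are my assumption):

import torch
from torchdeq import get_deq

deq = get_deq(
    ift=True,
    f_solver='broyden', f_max_iter=15, f_tol=1e-3, f_stop_mode='rel',
    b_solver='broyden', b_max_iter=15, b_tol=1e-3, b_stop_mode='rel',
)
lin = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
func = lambda z: torch.tanh(lin(z) + x)    # simple toy fixed-point map
z_out, info = deq(func, torch.zeros(4, 8))
print(z_out[-1].requires_grad)             # expected: True
z_out[-1].sum().backward()
print(lin.weight.grad is not None)         # expected: True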
If I print the requires_grad of the solution before and after the implicit function in each iteration, like this:
def layer_iter(hidden_states, input_hidden_states):
    print("before:", hidden_states.requires_grad)
    # forward pass through each transformer block, where input_hidden_states is the input embedding
    # ...
    print("after:", hidden_states.requires_grad)
    print("----")
    return hidden_states
func = lambda var: layer_iter(var, hidden_states_) # hidden_states_ is the input after the embedding layer
zeros_ = torch.zeros(*hidden_states_.shape, requires_grad=True).to(hidden_states_.device)
z_out, info = self.deq(func, zeros_)
print("z_out[-1].requires_grad:", z_out[-1].requires_grad)
print(info)
hidden_states = z_out[-1]
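One extra check that might help narrow this down (not part of my original run) would be printing the global grad mode and the incoming tensor's graph state inside layer_iter, to distinguish autograd being disabled during the solve from the input tensor simply being detached:

def layer_iter(hidden_states, input_hidden_states):
    # hypothetical extra diagnostics on top of the prints above
    print("grad mode:", torch.is_grad_enabled())
    print("input has grad_fn:", hidden_states.grad_fn is not None)
    # ... same forward pass through the transformer blocks ...
    return hidden_states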
The main script is:
import torch
from torchdeq import get_deq
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)
model.to(device)
batch = ["we", "we"]
inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
inputs = {key: value.to(device) for key, value in inputs.items()}
outputs = model(**inputs, use_cache=False)
Then I get:
before: True
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
z_out[-1].requires_grad: False
{'abs_lowest': tensor([914.4200, 914.4200], device='cuda:0'), 'rel_lowest': tensor([6.5291e-05, 6.5291e-05], device='cuda:0'), 'abs_trace': tensor([[1.0000e+08, 1.4693e+03, 1.1808e+03, 2.8000e+03, 2.1719e+03, 9.1442e+02,
1.0195e+03, 1.0512e+03, 1.0572e+03, 1.0586e+03, 9.1442e+02, 9.1442e+02,
9.1442e+02, 9.1442e+02, 9.1442e+02, 9.1442e+02],
[1.0000e+08, 1.4693e+03, 1.1808e+03, 2.8000e+03, 2.1719e+03, 9.1442e+02,
1.0195e+03, 1.0512e+03, 1.0572e+03, 1.0586e+03, 9.1442e+02, 9.1442e+02,
9.1442e+02, 9.1442e+02, 9.1442e+02, 9.1442e+02]], device='cuda:0'), 'rel_trace': tensor([[1.0000e+08, 1.3185e+00, 1.0195e+00, 6.6233e-01, 7.3519e-01, 2.3931e-01,
4.5128e-01, 1.2702e-02, 2.3903e-03, 6.5291e-05, 6.5291e-05, 6.5291e-05,
6.5291e-05, 6.5291e-05, 6.5291e-05, 6.5291e-05],
[1.0000e+08, 1.3185e+00, 1.0195e+00, 6.6233e-01, 7.3519e-01, 2.3931e-01,
4.5128e-01, 1.2702e-02, 2.3903e-03, 6.5291e-05, 6.5291e-05, 6.5291e-05,
6.5291e-05, 6.5291e-05, 6.5291e-05, 6.5291e-05]], device='cuda:0'), 'nstep': tensor([9., 9.], device='cuda:0'), 'sradius': tensor([0.])}
It seems the solver does converge in 9 iterations, but the returned hidden states have requires_grad=False, so the model's parameters (other than those after the DEQ layer) do not get gradients. I tried manually setting z_out[-1].requires_grad = True, but that doesn't help: after loss.backward(), .grad is still None for those parameters.
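As far as I understand, that part is expected: setting requires_grad=True on an already-detached tensor only makes future operations on it tracked; it cannot reconnect the tensor to the graph that produced it. A tiny illustration:

w = torch.randn(3, requires_grad=True)
y = (w * 2).detach()      # the graph back to w is gone at this point
y.requires_grad = True    # only future ops on y are tracked from now on
y.sum().backward()
print(w.grad)             # None: the connection to w is not restored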
Intriguingly, if I use a freshly-initialized GPT-2 then the issue seems to go away:
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)
# initialize the weights randomly
config = model.config
model = GPT2Model(config)
batch = ["we", "we"]
inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = {key: value.to(device) for key, value in inputs.items()}
model.to(device)
outputs = model(**inputs, use_cache=False)
and I get
before: True
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: True
----
z_out[-1].requires_grad: True
{'abs_lowest': tensor([9.0937, 9.0937], device='cuda:0'), 'rel_lowest': tensor([0.0002, 0.0002], device='cuda:0'), 'abs_trace': tensor([[1.0000e+08, 9.0937e+00, 9.2320e+00, 9.3630e+00, 9.4192e+00, 9.4344e+00,
9.4367e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00,
9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00],
[1.0000e+08, 9.0937e+00, 9.2320e+00, 9.3630e+00, 9.4192e+00, 9.4344e+00,
9.4367e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00,
9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00]], device='cuda:0'), 'rel_trace': tensor([[1.0000e+08, 5.5195e-01, 2.3347e-01, 9.2865e-02, 2.5290e-02, 3.8527e-03,
1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04,
1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04],
[1.0000e+08, 5.5195e-01, 2.3347e-01, 9.2865e-02, 2.5290e-02, 3.8527e-03,
1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04,
1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04]], device='cuda:0'), 'nstep': tensor([6., 6.], device='cuda:0'), 'sradius': tensor([-1.])}
max/min/mean steps: 6.0, 6.0, 6.0
Here requires_grad becomes True. Also, apart from requires_grad, the sradius in info is now -1 instead of 0; I'm not sure whether that's related.
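To isolate whether the pretrained weights themselves are what makes the difference, one further experiment (just a sketch, I haven't fully verified it) would be to load the pretrained state dict into a freshly constructed model and re-run the same repro:

pretrained = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)
fresh = GPT2Model(pretrained.config)
fresh.load_state_dict(pretrained.state_dict())  # fresh module object, pretrained weights
fresh.eval()   # from_pretrained also returns the model in eval mode
fresh.to(device)
# re-run the same forward/DEQ code with `fresh` to see whether
# requires_grad flips back to False once the pretrained weights are in place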
I wonder if you have ideas about why this happens. Would appreciate it!