```python
import torch
from functorch import make_functional_with_buffers, vmap, grad

# net, criterion, and train_poi_set are defined elsewhere in my setup.
fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)
gradient = ft_compute_grad(params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())
```
This returns the gradients of the whole model. However, I only want the second-to-last layer's gradient, like:

```python
gradient = ft_compute_grad(params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())[-2]
```
Although this method does obtain the required gradient, it incurs a lot of unnecessary overhead, because gradients are computed for every parameter. Is there any way to turn off `requires_grad` for all the earlier layers? Thanks for your answer!
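One workaround I have sketched is to pass only the parameter tensor of interest as the differentiated argument and close over the rest: `grad` differentiates with respect to `argnums=0` by default, so the other parameters are treated as constants and no gradients are materialized for them. This is a minimal sketch reusing `fmodel`, `params`, `buffers`, `criterion`, and `train_poi_set` from the snippet above; `layer_idx` is a hypothetical index for the parameter you want (`len(params) - 2` mirrors the `[-2]` indexing used above):

```python
# Sketch only: differentiate the loss w.r.t. a single parameter tensor.
# Assumes fmodel/params/buffers/criterion/train_poi_set exist as above;
# layer_idx is hypothetical and should point at the tensor you want.
layer_idx = len(params) - 2

def loss_wrt_single_param(target_param, sample, target):
    # Rebuild the full parameter tuple, swapping in the differentiated tensor;
    # the remaining entries are closed over and treated as constants by grad.
    full_params = params[:layer_idx] + (target_param,) + params[layer_idx + 1:]
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = fmodel(full_params, buffers, batch)
    return criterion(predictions, targets)

# grad differentiates w.r.t. argnums=0 by default, i.e. only target_param.
ft_compute_single_grad = grad(loss_wrt_single_param)
single_grad = ft_compute_single_grad(
    params[layer_idx],
    train_poi_set[0][0].cuda(),
    torch.tensor(train_poi_set[0][1]).cuda(),
)
```

As far as I understand reverse-mode autodiff, the backward pass would still traverse the layers after the chosen parameter (it has to, to reach it from the loss), but it would not build gradients for any of the other parameters. I am not sure whether this is the intended way to do it, hence the question.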