Description
I have a use-case for functorch. I would like to evaluate many interpolations of a model's parameters very efficiently (I want to eliminate the loop over them). Here's example code for a simplified case that I got working:
import torch
from functorch import vmap

linear = torch.nn.Linear(10, 2)
default_weight = linear.weight.data
sample_input = torch.rand(3, 10)
sample_add = torch.rand_like(default_weight)

def interpolate_weights(alpha):
    with torch.no_grad():
        # swap in the interpolated weight, then run the layer
        res_weight = torch.nn.Parameter(default_weight + alpha * sample_add)
        linear.weight = res_weight
    return linear(sample_input)
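A single call works as expected, e.g.:

# one forward pass at a single interpolation point
single_out = interpolate_weights(0.5)  # shape (3, 2)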
Now I could do for alpha in np.linspace(0.0, 1.0, 100), but I want to vectorise this loop since my code is prohibitively slow. Is functorch applicable here? Executing:
alphas = torch.linspace(0.0, 1.0, 100)
vmap(interpolate_weights)(alphas)
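For comparison, the sequential loop I'm trying to eliminate would look roughly like this (stacking the per-alpha outputs is just how I'd collect the results):

import numpy as np

# baseline: one forward pass per alpha, collected into a single tensor
outs = torch.stack([interpolate_weights(float(a)) for a in np.linspace(0.0, 1.0, 100)])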
The vmap version works, but I can't get something similar to work for a simple ResNet. I've tried using load_state_dict, but that doesn't work:
from torchvision import models

model_resnet = models.resnet18(pretrained=True)
named_params = list(model_resnet.named_parameters())
named_params_data = [(n, p.data.clone()) for (n, p) in named_params]
sample_data = torch.rand(10, 3, 224, 224)

def test_resnet(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            # interpolate every parameter and load the result back into the model
            p_dict = {name: (old + alpha * new_params[i])
                      for i, (name, old) in enumerate(named_params_data)}
            model_resnet.load_state_dict(p_dict, strict=False)
            out = model_resnet(sample_data)
        return out
    return interpolate

rand_tensor = [torch.rand_like(p) for n, p in named_params_data]
to_vmap_resnet = test_resnet(rand_tensor)
vmap(to_vmap_resnet)(alphas)
results in:
While copying the parameter named "fc.bias", whose dimensions in the model are torch.Size([1000]) and whose dimensions in the checkpoint are torch.Size([1000]), an exception occurred: ('vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.',)
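My understanding of the error is that load_state_dict copies the interpolated tensors into the module's parameters in place, and those parameters are not batched under vmap, hence the complaint about in-place arithmetic. Here is a rough sketch of what I think I'm after, using make_functional_with_buffers from functorch so the parameters are passed explicitly and combined out of place (eval mode and the handling of the BatchNorm buffers are my assumptions; I haven't verified this is the intended pattern):

from functorch import make_functional_with_buffers, vmap

model_resnet.eval()  # assumption: avoid in-place updates of BatchNorm running stats
fmodel, params, buffers = make_functional_with_buffers(model_resnet)
directions = tuple(torch.rand_like(p) for p in params)

def interpolate(alpha):
    # combine the parameters out of place and call the stateless model
    new_params = tuple(p + alpha * d for p, d in zip(params, directions))
    return fmodel(new_params, buffers, sample_data)

out = vmap(interpolate)(alphas)

Is something like this the recommended way to do it, or is there a better approach?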