Hello, I modified the tutorial code to get the following simple example of bi-level optimization. The results make sense to me. However, is this the right way of doing bi-level optimization with TorchOpt?

```python
import torch
import torch.autograd
import torch.nn as nn
import torchopt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.y = nn.Parameter(torch.tensor(1.0))

    def loss_lower(self, x):
        return 0.5 * (x - self.y) ** 2

    def loss_upper(self, x):
        return 0.5 * (x**2 + self.y**2)


net = Net()
x = nn.Parameter(torch.tensor(2.0))
opt_upper = torch.optim.SGD([x], lr=1e-1)   # outer (upper-level) optimizer on x
optim = torchopt.MetaSGD(net, lr=1e-1)      # differentiable inner (lower-level) optimizer on y

net.train()
for i in range(100):
    opt_upper.zero_grad()
    inner_loss = net.loss_lower(x)
    optim.step(inner_loss)                  # differentiable inner update of y
    outer_loss = net.loss_upper(x)
    # retain_graph seems necessary here
    # outer_loss.backward()
    outer_loss.backward(retain_graph=True)
    opt_upper.step()
    print("x:", x.detach().numpy(), "y:", net.y.detach().numpy())
```
Hi @xianghang, after each meta-optimization step (i.e., before the next meta-optimization loop), you need to use `torchopt.stop_gradient` to detach the previous computational graph. See also our tutorial tutorials/4_Stop_Gradient.ipynb. Here is an example:

```python
import torch
import torch.nn as nn
import torchopt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.y = nn.Parameter(torch.tensor(1.0))

    def loss_lower(self, x):
        return 0.5 * (x - self.y) ** 2

    def loss_upper(self, x):
        return 0.5 * (x**2 + self.y**2)


net = Net().train()
x = nn.Parameter(torch.tensor(2.0))
optim_upper = torchopt.SGD([x], lr=1e-1)
optim = torchopt.MetaSGD(net, lr=1e-1)

for iter_upper in range(1, 51):
    print(f'Iter {iter_upper}:')
    for iter_lower in range(1, 11):
        inner_loss = net.loss_lower(x)
        optim.step(inner_loss)
        print(
            '',
            f'Inner-iter({iter_upper:>2}.{iter_lower:<2}):',
            f'x={x.item():.8f}',
            f'y={net.y.item():.8f}',
            sep=' ',
        )
    outer_loss = net.loss_upper(x)
    optim_upper.zero_grad()
    outer_loss.backward()
    optim_upper.step()
    torchopt.stop_gradient(net)    # <-- detach graph
    torchopt.stop_gradient(optim)

print('Final:', f'x={x.item():.8f}', f'y={net.y.item():.8f}', sep=' ')
```

Outputs and visualization omitted.
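If you want to double-check that the graph is really detached, one quick sanity check (a rough sketch reusing the variables above, not re-run here) is to look at `net.y.grad_fn` before and after the call:

```python
# Rough sanity check (assuming the setup above): after a differentiable
# MetaSGD step, net.y is a non-leaf tensor attached to the inner-loop graph;
# torchopt.stop_gradient turns it back into a detached leaf.
inner_loss = net.loss_lower(x)
optim.step(inner_loss)
print(net.y.grad_fn, net.y.is_leaf)   # some grad_fn node, False

torchopt.stop_gradient(net)           # <-- detach graph
torchopt.stop_gradient(optim)
print(net.y.grad_fn, net.y.is_leaf)   # None, True
```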