-
Hi, I would appreciate any thoughts on this. I am working on bi-level optimization where a validation set is used in the outer optimization. I am using torchopt.distributed on a single GPU node with 4 workers. My network is a modified vision transformer. My goal, for now, is to get the gradients of the validation loss w.r.t. y. However, when I call distributed autograd backward on the validation loss, it either gets stuck or is extremely slow. The only error I get is when it exceeds the RPC timeout. I have tried a few things to investigate this.
I have also considered doing the validation loop inside inner_loop, but that seems inefficient, since validation would then be done on a premature model on each worker. I have followed the documentation on distributed training and the MAML sample. Here is a rough pseudo-code:
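(Simplified sketch only; I am assuming here that torchopt.distributed is imported as todist, that the distributed autograd context comes from torch.distributed.autograd, and that meta_opt is a TorchOpt meta-optimizer such as torchopt.MetaAdam. Function names and decorator arguments are illustrative, not my exact code.)

```python
import torch
import torchopt
import torchopt.distributed as todist
import torch.distributed.autograd as dist_autograd

# net (the modified ViT), meta_opt (e.g. torchopt.MetaAdam(net, ...)),
# train_loader and val_loader are created elsewhere.

@todist.parallelize()  # partitioner/reducer arguments omitted
def inner_loop(net, x_train, y_train):
    loss = net(x_train, y_train)   # the model applies its built-in criterion
    meta_opt.step(loss)            # differentiable inner update
    return loss

@todist.parallelize()
def val_loop(net, x_valid, y_valid):
    return net(x_valid, y_valid)   # validation loss

with dist_autograd.context() as context_id:
    for x, y in train_loader:
        inner_loop(net, x, y)
    for vx, vy in val_loader:
        val_loss = val_loop(net, vx, vy)
        # this per-batch distributed backward is where it hangs or slows down
        todist.autograd.backward(context_id, val_loss)
```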
-
Hi @mmargalo. There is some communication cost during the distributed training procedure. I also suggest you fuse the two functions:

```python
def inner_loop(net, x_train, y_train):
    # ... clone net by reference
    loss = net(x_train, y_train)  # the model applies its built-in criterion
    meta_opt.step(loss)
    return loss, net

def val_loss(net, x_valid, y_valid):
    # ... clone net by reference
    loss = net(x_valid, y_valid)  # the model applies its built-in criterion
    return loss

@parallelize(...)
def compute_outer_loss(net, x_train, y_train, x_valid, y_valid):
    train_loss, trained_net = inner_loop(net, x_train, y_train)
    valid_loss = val_loss(trained_net, x_valid, y_valid)
    return train_loss, valid_loss
```

In the backward loop, I suggest you collect the losses and call backward only once rather than backward multiple times:

```python
val_losses = []
for vx, vy in val_loader:
    val_loss = val_loop(model_clone, vx, vy)
    val_losses.append(val_loss)
val_loss = torch.stack(val_losses).sum()  # or mean()
todist.autograd.backward(context_id, val_loss)
```

Also, it seems that your computation graph is growing too large in the for-loop. Have you tried running fewer loop iterations (e.g., by increasing the batch size)?
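For completeness, here is a minimal sketch of where context_id comes from and how the resulting gradients can be read back, using the plain torch.distributed.autograd API (the todist.autograd.backward call above presumably wraps the same mechanism). val_loop, model_clone, val_loader, and the outer parameters outer_params are assumed to be defined as in the snippets above:

```python
import torch
import torch.distributed.autograd as dist_autograd

with dist_autograd.context() as context_id:
    val_losses = []
    for vx, vy in val_loader:
        val_losses.append(val_loop(model_clone, vx, vy))
    val_loss = torch.stack(val_losses).sum()  # or mean()

    # a single distributed backward for the whole outer step
    dist_autograd.backward(context_id, [val_loss])

    # distributed autograd stores gradients in the context rather than in .grad
    grads = dist_autograd.get_gradients(context_id)           # {tensor: grad}
    outer_grads = [grads[p] for p in outer_params if p in grads]
```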
-
I see, so the model is not updated even if it is referenced on the workers. Thanks for the advice; I ended up merging the inner and val loops, with the val loop going over a small batch rather than the entire validation set.
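Roughly, what I ended up with looks like the sketch below (names are illustrative and simplified; compute_outer_loss is the fused function from your reply, and I am using the plain torch.distributed.autograd calls here):

```python
import torch
import torch.distributed.autograd as dist_autograd

# draw one small, fixed validation batch instead of iterating the whole set
x_valid, y_valid = next(iter(val_loader))

for x_train, y_train in train_loader:
    with dist_autograd.context() as context_id:
        # inner update and validation loss are computed together on each worker
        train_loss, valid_loss = compute_outer_loss(net, x_train, y_train, x_valid, y_valid)
        dist_autograd.backward(context_id, [valid_loss])
        grads = dist_autograd.get_gradients(context_id)
        # ... use grads for the outer/meta update
```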