@@ -30,7 +30,7 @@
 
 # sys.path.append("..")
 from utils.common_utils import generate_task_id, TASK2ID, ID2TASK
-from utils.loss_utils import loss_func_mft, SelfpacedStatus, load_balancing_loss_func
+from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func
 
 logger = get_logger(__name__)
 
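The import swap above is the heart of this PR: the self-paced weighting scheme is replaced by CoBa (convergence-balancing) multi-task loss weighting. For context, these are the `coba_*` arguments the trainer reads from `self.args` in the hunks below, collected as a hypothetical config sketch; the field names come from this diff, but the values are illustrative assumptions, not the project's defaults.

import argparse

# Hypothetical defaults for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument("--weighted_loss_mode", default="coba")
parser.add_argument("--coba_warmup_steps", type=int, default=100)      # CoBa weights are only computed after this many steps
parser.add_argument("--coba_history_length", type=int, default=200)    # validation-loss window kept per task
parser.add_argument("--coba_tau", type=float, default=5.0)             # temperature for the weight computation
parser.add_argument("--coba_update_interval", type=int, default=1)     # recompute weights every N completed steps
parser.add_argument("--coba_sample_valid_num", type=int, default=1)    # validation batches sampled per weight update
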
@@ -239,7 +239,7 @@ def accelerate_monitor(
         reduce_task_loss,
         reduce_task_exist,
         completed_steps,
-        selfpaced_status=None,
+        coba_status=None,
     ):
         """
         gather reduce_loss and reduce_task_loss from all N devices.
@@ -263,27 +263,27 @@ def accelerate_monitor(
             f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]",
             main_process_only=True,
         )
-        if selfpaced_status is not None:
-            if completed_steps > selfpaced_status.selfpaced_history_length:
-                selfpaced_status.log_per_task_weight = selfpaced_status.log_per_task_weight / torch.sum(
-                    selfpaced_status.log_per_task_weight
+        if coba_status is not None:
+            if completed_steps > coba_status.coba_warmup_steps:
+                coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum(
+                    coba_status.log_per_task_weight
                 )
             else:
-                selfpaced_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
+                coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
             logger.info(
-                f"[TRAIN][per_task_train_weight={selfpaced_status.log_per_task_weight}]", main_process_only=True
+                f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True
             )
         train_log_dict = {"Loss/train": train_loss}
         for i in range(len(ID2TASK)):
             train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i]
-            if selfpaced_status is not None:
-                train_log_dict[f"{ID2TASK[i]}_selfpaced_weight/train"] = selfpaced_status.log_per_task_weight[i].item()
+            if coba_status is not None:
+                train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item()
 
         if self.accelerator.is_main_process:
             write_tensorboard(self.summary_writer, train_log_dict, completed_steps)
 
-        if selfpaced_status is not None:
-            selfpaced_status.log_per_task_weight = torch.zeros(len(ID2TASK))
+        if coba_status is not None:
+            coba_status.log_per_task_weight = torch.zeros(len(ID2TASK))
 
     def accelerate_evaluate(
         self,
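A note on the logging logic in this hunk: `log_per_task_weight` accumulates the weight vector chosen at each CoBa update, so dividing by its sum before logging yields each task's average relative weight since the last report; during warmup (no updates yet) a uniform vector is logged instead, and the accumulator is zeroed after the TensorBoard write. A minimal numeric sketch with illustrative values:

import torch

# Two tasks, weights accumulated over several updates since the last log.
log_per_task_weight = torch.tensor([1.2, 0.8])
print(log_per_task_weight / torch.sum(log_per_task_weight))
# tensor([0.6000, 0.4000]) -> task 0 received 60% of the relative weight
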
@@ -416,18 +416,29 @@ def accelerate_train(self):
         reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
         per_task_weight = self.args.task_weights
 
-        if self.args.weighted_loss_mode == "selfpaced":
-            selfpaced_status = SelfpacedStatus(
-                self.args.selfpaced_scale_factor,
-                self.args.selfpaced_interval,
-                self.args.selfpaced_history_length,
-                self.args.selfpaced_sample_valid_num,
+        if self.args.weighted_loss_mode == "coba":
+            self.model.eval()
+            eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate(
+                completed_steps,
+                0,
+                min_eval_loss,
+                stall_num,
+                best_step,
+            )
+            self.model.train()
+            coba_status = CoBaStatus(
+                self.args.coba_warmup_steps,
+                self.args.coba_history_length,
+                self.args.coba_tau,
+                self.args.coba_update_interval,
+                self.args.coba_sample_valid_num,
                 self.valid_dataloader,
             )
-            selfpaced_status.sample_valid_batch(self.model, completed_steps)
-            selfpaced_status.valid_iterator = iter(selfpaced_status.valid_dataloader)
+            coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device)
+            coba_status.sample_valid_batch(self.model, completed_steps)
+            logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
         else:
-            selfpaced_status = None
+            coba_status = None
 
         # Training Loop!
         for epoch in range(starting_epoch, self.args.num_train_epochs):
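The hunk above first runs a full evaluation to record each task's starting validation loss (`valid_task_loss_begining`, spelled as in the source) before building the `CoBaStatus` tracker. The real class lives in `utils/loss_utils.py`; the sketch below only restates the interface this call site implies, and the per-field comments are my assumptions, not documented behavior.

import torch

from utils.common_utils import ID2TASK  # same project import used at the top of this file


class CoBaStatus:
    """Interface sketch inferred from this diff; internals are assumptions."""

    def __init__(
        self,
        coba_warmup_steps,
        coba_history_length,
        coba_tau,
        coba_update_interval,
        coba_sample_valid_num,
        valid_dataloader,
    ):
        self.coba_warmup_steps = coba_warmup_steps
        self.coba_history_length = coba_history_length
        self.coba_tau = coba_tau
        self.coba_update_interval = coba_update_interval
        self.coba_sample_valid_num = coba_sample_valid_num
        self.valid_dataloader = valid_dataloader
        self.valid_task_loss_begining = None      # set by the trainer from the initial evaluation
        self.valid_task_loss_accumulated = None   # filled by sample_valid_batch()
        self.log_per_task_weight = torch.zeros(len(ID2TASK))
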
@@ -463,13 +474,15 @@ def accelerate_train(self):
                 )
 
                 if (
-                    self.args.weighted_loss_mode == "selfpaced"
-                    and step % self.args.gradient_accumulation_steps == 0
-                    and completed_steps % self.args.selfpaced_interval == 0
-                    and completed_steps >= self.args.selfpaced_history_length
+                    self.args.weighted_loss_mode == "coba"
+                    and self.accelerator.sync_gradients
+                    and completed_steps % self.args.coba_update_interval == 0
+                    and completed_steps >= self.args.coba_warmup_steps
                 ):
-                    per_task_weight = selfpaced_status.compute_per_task_weight(completed_steps=completed_steps)
-                    selfpaced_status.log_per_task_weight += per_task_weight
+                    with torch.no_grad():
+                        per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps)
+                        coba_status.log_per_task_weight += per_task_weight
+                    # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True)
 
                 # loss
                 loss, task_loss, _ = loss_func_mft(
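Two details worth noting in this hunk: the trigger now uses `self.accelerator.sync_gradients` (true exactly when an optimizer step happens) instead of the manual `step % gradient_accumulation_steps` check, and the weight computation is wrapped in `torch.no_grad()` so it never builds an autograd graph. The toy function below shows one plausible shape for a convergence-based weighting rule; it is NOT the real `compute_per_task_weight`, whose scoring and normalization may differ.

import torch


def toy_per_task_weight(valid_loss_history: torch.Tensor, coba_tau: float) -> torch.Tensor:
    """Toy convergence-slope weighting; not the actual CoBa algorithm.

    valid_loss_history: (num_tasks, history_length) recent per-task validation losses.
    """
    num_tasks, history_length = valid_loss_history.shape
    steps = torch.arange(history_length, dtype=torch.float32)
    steps = steps - steps.mean()
    # Least-squares slope of each task's validation loss over the window.
    slopes = (valid_loss_history * steps).sum(dim=1) / (steps * steps).sum()
    # In this toy version a flat or rising loss curve gets more weight;
    # scaled so the weights average to 1 across tasks.
    return torch.softmax(slopes * coba_tau, dim=0) * num_tasks
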
@@ -524,11 +537,12 @@ def accelerate_train(self):
                 # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done.
                 if self.accelerator.sync_gradients:
                     if (
-                        self.args.weighted_loss_mode == "selfpaced"
-                        and completed_steps % self.args.selfpaced_interval == 0
+                        self.args.weighted_loss_mode == "coba"
+                        and completed_steps % self.args.coba_update_interval == 0
                         and completed_steps >= 1
                     ):
-                        selfpaced_status.sample_valid_batch(self.model, completed_steps)
+                        coba_status.sample_valid_batch(self.model, completed_steps)
+                        # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
 
                     # progress_bar.update(1)
                     completed_steps += 1
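`sample_valid_batch` is now also called during training every `coba_update_interval` completed steps, so the tracker has fresh validation losses whenever the weights are recomputed. The sketch below shows the bookkeeping this call site appears to rely on; it is an assumption, and the real method additionally has to run batches through the model and handle validation-iterator wrap-around.

import torch


def sketch_sample_valid_batch(per_task_batch_losses, coba_sample_valid_num):
    """Assumed bookkeeping: average per-task loss over a few validation batches.

    per_task_batch_losses: iterator yielding a (num_tasks,) loss tensor per batch.
    """
    accumulated = None
    with torch.no_grad():
        for _, batch_loss in zip(range(coba_sample_valid_num), per_task_batch_losses):
            accumulated = batch_loss if accumulated is None else accumulated + batch_loss
    # This average is what the trainer logs as valid_task_loss_accumulated.
    return accumulated / coba_sample_valid_num
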
@@ -542,7 +556,7 @@ def accelerate_train(self):
                         reduce_task_loss,
                         reduce_task_exist,
                         completed_steps,
-                        selfpaced_status,
+                        coba_status,
                     )
                     # reset reduce_loss
                     reduce_loss = torch.tensor(0.0).to(self.model.device)