Commit a4cd0b9

Merge pull request #81 from codefuse-ai/support_coba_loss
Support coba loss
2 parents a68a716 + 65f4511 commit a4cd0b9

12 files changed, +231 -148 lines changed
README.md

+2
@@ -46,6 +46,8 @@
 
 
 ## News
+🔥🔥🔥 [2024/10/29] Our paper [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) has been accepted by EMNLP-2024, which achieves balanced convergence across various tasks.
+
 🔥🔥🔥 [2024/05/20] We released **MFTCoder v0.4**, mainly for MFTCoder-accelerate. It supports **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** as options allowing you training very large models. It now supports new models like Qwen2, Qwen2-MoE, Starcoder2, Gemma, etc.
 
 🔥🔥🔥 [2024/05/20] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD2024.

README_cn.md

+2
@@ -45,6 +45,8 @@
 
 
 ## News
+🔥🔥🔥 [2024/10/29] Our paper [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) has been accepted by EMNLP 2024; it achieves balanced convergence across multiple tasks.
+
 🔥🔥🔥 [2024/05/20] **MFTCoder-v0.4** released. It adds the **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** training modes, which better support fine-tuning larger models such as Qwen1.5-70B. It also adds support for models including Qwen2, Qwen2-MoE, Starcoder2, and Gemma.
 
 🔥🔥🔥 [2024/05/20] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD 2024.

mftcoder_accelerate/README.md

+2
@@ -7,6 +7,8 @@
 [[中文]](README_cn.md) [**English**]
 
 ## 1. Updates
+🔥 MFTCoder-accelerate supports latest implementation of CoBa Loss (selfpaced Loss) for better Convergence Balance.
+
 🔥 MFTCoder-accelerate now support these modes: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, Full-parameter + DeepSpeed ZeRO3, QLoRA + FSDP, Full-parameter + FSDP.
 
 🔥 MFTCoder-accelerate supports QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which both work for larger models;

mftcoder_accelerate/README_cn.md

+2
@@ -7,6 +7,8 @@
 [**中文**] [[English]](README.md)
 
 ## 1. Updates
+🔥 MFTCoder-accelerate adds the latest implementation of CoBa Loss (formerly selfpaced Loss), further improving convergence balance.
+
 🔥 MFTCoder-accelerate now supports these training modes: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, full-parameter + DeepSpeed ZeRO3, QLoRA + FSDP, full-parameter + FSDP.
 
 🔥 MFTCoder-accelerate adds support for QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which can train larger models;

mftcoder_accelerate/src/configs/selfpaced_train_config.json renamed to mftcoder_accelerate/src/configs/coba_train_config.json

+10 -9
@@ -5,16 +5,17 @@
     "pretrained_model_path": "$MODEL_NAME_OR_PATH",
     "model_type": "$MODEL_TYPE",
     "load_raw_dataset": true,
-    "data_split": "98,2,0",
+    "data_split": "95,5,0",
     "padding_mode": "padding",
     "use_dynamic_padding": true,
     "tokenize_mode": "sft",
     "tokenizer_type": "AutoTokenizer",
-    "weighted_loss_mode": "selfpaced",
-    "selfpaced_interval": 1,
-    "selfpaced_history_length": 100,
-    "selfpaced_sample_valid_num": 1,
-    "selfpaced_scale_factor": 50,
+    "weighted_loss_mode": "coba",
+    "coba_warmup_steps": 100,
+    "coba_history_length": 200,
+    "coba_tau": 5,
+    "coba_update_interval": 1,
+    "coba_sample_valid_num": 1,
     "attn_implementation": "flash_attention_2",
     "seq_length": 4096,
     "seed": 1234,
@@ -23,8 +24,8 @@
     "lora_rank": 96,
     "lora_alpha": 32,
     "lora_dropout": 0.05,
-    "per_device_train_batch_size": 2,
-    "per_device_eval_batch_size": 2,
+    "per_device_train_batch_size": 8,
+    "per_device_eval_batch_size": 8,
     "learning_rate": 5e-5,
     "min_lr": 5e-6,
     "weight_decay": 0.1,
@@ -42,4 +43,4 @@
     "early_stopping": true,
     "early_stopping_stall_num": 5,
     "saving_limit": null
-}
+}
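
For anyone carrying an older config forward, the key renames in this file can be sketched as a small migration helper. This is only an illustration under assumptions: `migrate_selfpaced_config`, `RENAMED_KEYS`, `NEW_DEFAULTS`, and `DROPPED_KEYS` are hypothetical names, not part of the repository. The sketch drops `selfpaced_scale_factor` because the diff shows no direct replacement for it (`coba_tau` is described as a temperature, not a scale factor), and it fills the two new keys with the defaults shown in this commit.

```python
import json

# Hypothetical mapping inferred from this diff; not a repository API.
RENAMED_KEYS = {
    "selfpaced_interval": "coba_update_interval",
    "selfpaced_history_length": "coba_history_length",
    "selfpaced_sample_valid_num": "coba_sample_valid_num",
}
# Keys introduced by this commit, with the default values shown in the diff.
NEW_DEFAULTS = {"coba_warmup_steps": 100, "coba_tau": 5}
# Assumption: the old scale factor has no one-to-one successor, so it is dropped.
DROPPED_KEYS = {"selfpaced_scale_factor"}


def migrate_selfpaced_config(old: dict) -> dict:
    """Return a copy of `old` with selfpaced_* keys renamed to their coba_* names."""
    if old.get("weighted_loss_mode") != "selfpaced":
        return dict(old)  # nothing to migrate
    new = {}
    for key, value in old.items():
        if key in DROPPED_KEYS:
            continue
        new[RENAMED_KEYS.get(key, key)] = value
    new["weighted_loss_mode"] = "coba"
    for key, value in NEW_DEFAULTS.items():
        new.setdefault(key, value)
    return new


old_cfg = {
    "weighted_loss_mode": "selfpaced",
    "selfpaced_interval": 1,
    "selfpaced_history_length": 100,
    "selfpaced_sample_valid_num": 1,
    "selfpaced_scale_factor": 50,
    "seq_length": 4096,
}
new_cfg = migrate_selfpaced_config(old_cfg)
print(json.dumps(new_cfg, indent=4))
```

Note that the committed config also raises `coba_history_length` to 200; the helper keeps whatever history length the old config used rather than guessing a new value.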

mftcoder_accelerate/src/mpt/mpt_accelerate.py

+1 -1
@@ -209,7 +209,7 @@ def prepare_args():
     # generate TASK2ID, ID2TASK
     generate_task_id(args.data_paths)
 
-    if args.weighted_loss_mode == "selfpaced":
+    if args.weighted_loss_mode == "coba":
         args.task_weights = [1.0] * len(ID2TASK)
     elif args.task_weights is not None:
         args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")]
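
The branch above can be read as a small pure function: in `coba` mode every task starts at weight 1.0, otherwise a bracketed string such as `"[0.5,2]"` is parsed into floats. The sketch below is a standalone restatement for illustration; `resolve_task_weights` is a hypothetical name, and the behaviour when no weights are given is an assumption because that branch is not visible in this hunk.

```python
from typing import List, Optional


def resolve_task_weights(mode: str, raw: Optional[str], num_tasks: int) -> Optional[List[float]]:
    """Standalone restatement of the task-weight branch in prepare_args()."""
    if mode == "coba":
        # CoBa recomputes weights during training, so start from uniform ones.
        return [1.0] * num_tasks
    if raw is not None:
        # Strip the surrounding brackets, then parse the comma-separated floats.
        return [float(wt) for wt in raw[1:-1].split(",")]
    # Assumption: with no explicit weights the value is left as None.
    return None


print(resolve_task_weights("coba", "[3,4]", 3))
print(resolve_task_weights("case3", "[0.5,2]", 2))
```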

mftcoder_accelerate/src/mpt/mpt_arguments.py

+10 -8
@@ -126,14 +126,16 @@ class MptTrainArgs:
     # if dynamic padding
     use_dynamic_padding: bool = True
 
-    # interval of update per task train weight in selfpaced
-    selfpaced_interval: int = 1
-    # history length of sample valid loss used to fit the slope curve in selfpaced
-    selfpaced_history_length: int = 100
-    # the number of mini valid batches sampled at each interval
-    selfpaced_sample_valid_num: int = 1
-    # scale factor before softmax
-    selfpaced_scale_factor: int = 50
+    # warm-up steps for CoBa, recommand the number of valid batches
+    coba_warmup_steps: int = 100
+    # history length of sample valid loss used to fit the slope curve in CoBa
+    coba_history_length: int = 200
+    # temperature for divergence factor in CoBa
+    coba_tau: int = 5
+    # iteration interval of update per task train weight in CoBa
+    coba_update_interval: int = 1
+    # the number of mini valid batches sampled at each updated iteration interval
+    coba_sample_valid_num: int = 1
 
     # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2}
     attn_implementation: str = "flash_attention_2"
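
The comment on `coba_tau` calls it a temperature. As a generic illustration of what a softmax temperature does (this is not the CoBa formula, which is implemented in `utils/loss_utils.py` and is not shown in this diff), the sketch below divides scores by `tau` before normalizing; the function name `softmax_with_temperature` and the sample scores are assumptions made for the example.

```python
import math
from typing import List


def softmax_with_temperature(scores: List[float], tau: float) -> List[float]:
    """Generic temperature softmax: larger tau gives a flatter distribution."""
    scaled = [s / tau for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0]  # made-up per-task scores
sharp = softmax_with_temperature(scores, tau=1.0)
flat = softmax_with_temperature(scores, tau=5.0)
print(sharp)
print(flat)
```

With the larger temperature the largest weight shrinks and the smallest grows, so a higher `tau` should be expected to make any softmax-based weighting less aggressive.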

mftcoder_accelerate/src/mpt/mpt_trainer.py

+45 -31
@@ -30,7 +30,7 @@
 
 # sys.path.append("..")
 from utils.common_utils import generate_task_id, TASK2ID, ID2TASK
-from utils.loss_utils import loss_func_mft, SelfpacedStatus, load_balancing_loss_func
+from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func
 
 logger = get_logger(__name__)
 
@@ -239,7 +239,7 @@ def accelerate_monitor(
         reduce_task_loss,
         reduce_task_exist,
         completed_steps,
-        selfpaced_status=None,
+        coba_status=None,
     ):
         """
         gather reduce_loss and reduce_task_loss from all N devices.
@@ -263,27 +263,27 @@ def accelerate_monitor(
             f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]",
             main_process_only=True,
         )
-        if selfpaced_status is not None:
-            if completed_steps > selfpaced_status.selfpaced_history_length:
-                selfpaced_status.log_per_task_weight = selfpaced_status.log_per_task_weight / torch.sum(
-                    selfpaced_status.log_per_task_weight
+        if coba_status is not None:
+            if completed_steps > coba_status.coba_warmup_steps:
+                coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum(
+                    coba_status.log_per_task_weight
                 )
             else:
-                selfpaced_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
+                coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK)
             logger.info(
-                f"[TRAIN][per_task_train_weight={selfpaced_status.log_per_task_weight}]", main_process_only=True
+                f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True
             )
         train_log_dict = {"Loss/train": train_loss}
         for i in range(len(ID2TASK)):
             train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i]
-            if selfpaced_status is not None:
-                train_log_dict[f"{ID2TASK[i]}_selfpaced_weight/train"] = selfpaced_status.log_per_task_weight[i].item()
+            if coba_status is not None:
+                train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item()
 
         if self.accelerator.is_main_process:
             write_tensorboard(self.summary_writer, train_log_dict, completed_steps)
 
-        if selfpaced_status is not None:
-            selfpaced_status.log_per_task_weight = torch.zeros(len(ID2TASK))
+        if coba_status is not None:
+            coba_status.log_per_task_weight = torch.zeros(len(ID2TASK))
 
     def accelerate_evaluate(
         self,
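
The logging change in `accelerate_monitor` can be summarized as: weights accumulated since the last log are normalized to sum to 1 once training is past `coba_warmup_steps`, and reported as uniform before that. A minimal sketch using plain lists instead of tensors follows; `logged_task_weights` is a hypothetical name, and the code assumes at least one weight update was accumulated after warm-up (otherwise the sum would be zero).

```python
from typing import List


def logged_task_weights(accumulated: List[float], completed_steps: int, warmup_steps: int) -> List[float]:
    """List-based restatement of the per-task weight normalization used for logging."""
    num_tasks = len(accumulated)
    if completed_steps > warmup_steps:
        # Past warm-up: report each task's share of the accumulated weight.
        total = sum(accumulated)
        return [w / total for w in accumulated]
    # During warm-up no weights have been computed yet, so report uniform weights.
    return [1.0 / num_tasks] * num_tasks


print(logged_task_weights([2.0, 6.0], completed_steps=150, warmup_steps=100))
print(logged_task_weights([0.0, 0.0], completed_steps=50, warmup_steps=100))
```

After logging, the diff resets the accumulator to zeros so the next log interval starts fresh.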
@@ -416,18 +416,29 @@ def accelerate_train(self):
         reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device)
         per_task_weight = self.args.task_weights
 
-        if self.args.weighted_loss_mode == "selfpaced":
-            selfpaced_status = SelfpacedStatus(
-                self.args.selfpaced_scale_factor,
-                self.args.selfpaced_interval,
-                self.args.selfpaced_history_length,
-                self.args.selfpaced_sample_valid_num,
+        if self.args.weighted_loss_mode == "coba":
+            self.model.eval()
+            eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate(
+                completed_steps,
+                0,
+                min_eval_loss,
+                stall_num,
+                best_step,
+            )
+            self.model.train()
+            coba_status = CoBaStatus(
+                self.args.coba_warmup_steps,
+                self.args.coba_history_length,
+                self.args.coba_tau,
+                self.args.coba_update_interval,
+                self.args.coba_sample_valid_num,
                 self.valid_dataloader,
             )
-            selfpaced_status.sample_valid_batch(self.model, completed_steps)
-            selfpaced_status.valid_iterator = iter(selfpaced_status.valid_dataloader)
+            coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device)
+            coba_status.sample_valid_batch(self.model, completed_steps)
+            logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
         else:
-            selfpaced_status = None
+            coba_status = None
 
         # Training Loop!
         for epoch in range(starting_epoch, self.args.num_train_epochs):
@@ -463,13 +474,15 @@ def accelerate_train(self):
                     )
 
                     if (
-                        self.args.weighted_loss_mode == "selfpaced"
-                        and step % self.args.gradient_accumulation_steps == 0
-                        and completed_steps % self.args.selfpaced_interval == 0
-                        and completed_steps >= self.args.selfpaced_history_length
+                        self.args.weighted_loss_mode == "coba"
+                        and self.accelerator.sync_gradients
+                        and completed_steps % self.args.coba_update_interval == 0
+                        and completed_steps >= self.args.coba_warmup_steps
                     ):
-                        per_task_weight = selfpaced_status.compute_per_task_weight(completed_steps=completed_steps)
-                        selfpaced_status.log_per_task_weight += per_task_weight
+                        with torch.no_grad():
+                            per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps)
+                            coba_status.log_per_task_weight += per_task_weight
+                            # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True)
 
                     # loss
                     loss, task_loss, _ = loss_func_mft(
@@ -524,11 +537,12 @@ def accelerate_train(self):
                 # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done.
                 if self.accelerator.sync_gradients:
                     if (
-                        self.args.weighted_loss_mode == "selfpaced"
-                        and completed_steps % self.args.selfpaced_interval == 0
+                        self.args.weighted_loss_mode == "coba"
+                        and completed_steps % self.args.coba_update_interval == 0
                         and completed_steps >= 1
                     ):
-                        selfpaced_status.sample_valid_batch(self.model, completed_steps)
+                        coba_status.sample_valid_batch(self.model, completed_steps)
+                        # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True)
 
                     # progress_bar.update(1)
                     completed_steps += 1
@@ -542,7 +556,7 @@ def accelerate_train(self):
                             reduce_task_loss,
                             reduce_task_exist,
                             completed_steps,
-                            selfpaced_status,
+                            coba_status,
                         )
                         # reset reduce_loss
                         reduce_loss = torch.tensor(0.0).to(self.model.device)
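
The two conditions changed in `accelerate_train` decide when per-task weights are recomputed and when a validation batch is sampled. They can be restated as standalone predicates to make the schedule easy to inspect. The function names and the example values (`warmup_steps=3`, `update_interval=1`) are assumptions chosen for illustration; the boolean logic follows the added lines in the hunks above.

```python
def should_update_weights(mode: str, sync_gradients: bool, completed_steps: int,
                          update_interval: int, warmup_steps: int) -> bool:
    """Weights are recomputed only in coba mode, on a real optimizer step, after warm-up."""
    return (
        mode == "coba"
        and sync_gradients
        and completed_steps % update_interval == 0
        and completed_steps >= warmup_steps
    )


def should_sample_valid(mode: str, sync_gradients: bool, completed_steps: int,
                        update_interval: int) -> bool:
    """A validation batch is sampled on each update interval once one step has completed."""
    return (
        sync_gradients
        and mode == "coba"
        and completed_steps % update_interval == 0
        and completed_steps >= 1
    )


# Walk a few steps with small, made-up settings to see the schedule.
for step in range(6):
    sample = should_sample_valid("coba", True, step, update_interval=1)
    update = should_update_weights("coba", True, step, update_interval=1, warmup_steps=3)
    print(step, sample, update)
```

One visible consequence of the new condition is that weight updates are gated on `self.accelerator.sync_gradients` rather than on `step % gradient_accumulation_steps`, and the warm-up threshold is now `coba_warmup_steps` instead of the history length.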

mftcoder_accelerate/src/pefts/mft_accelerate.py

+1 -1
@@ -217,7 +217,7 @@ def prepare_args():
     # generate TASK2ID, ID2TASK
     generate_task_id(args.data_paths)
 
-    if args.weighted_loss_mode == "selfpaced":
+    if args.weighted_loss_mode == "coba":
         args.task_weights = [1.0] * len(ID2TASK)
     elif args.task_weights is not None:
         args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")]

mftcoder_accelerate/src/pefts/mft_arguments.py

+10 -8
@@ -141,14 +141,16 @@ class MftTrainArgs:
     # if dynamic padding
     use_dynamic_padding: bool = True
 
-    # interval of update per task train weight in selfpaced
-    selfpaced_interval: int = 1
-    # history length of sample valid loss used to fit the slope curve in selfpaced
-    selfpaced_history_length: int = 100
-    # the number of mini valid batches sampled at each interval
-    selfpaced_sample_valid_num: int = 1
-    # scale factor before softmax
-    selfpaced_scale_factor: int = 50
+    # warm-up steps for CoBa, recommand the number of valid batches
+    coba_warmup_steps: int = 100
+    # history length of sample valid loss used to fit the slope curve in CoBa
+    coba_history_length: int = 200
+    # temperature for divergence factor in CoBa
+    coba_tau: int = 5
+    # iteration interval of update per task train weight in CoBa
+    coba_update_interval: int = 1
+    # the number of mini valid batches sampled at each updated iteration interval
+    coba_sample_valid_num: int = 1
 
     # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2}
     attn_implementation: str = "flash_attention_2"
