
Commit 7dbba98

[RL] Fix typo and add wandb log (#10641)
* add log
* add INFERENCE_TRUNCATED_RETURN_EOS
* add INFERENCE_TRUNCATED_RETURN_EOS
1 parent eee766b commit 7dbba98

File tree

6 files changed: +8 -7 lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+../../../../../../llm/devices/intel_hpu/tests/README.md

docs/zh/llm/docs/pretrain.md

Lines changed: 0 additions & 1 deletion
This file was deleted.

docs/zh/llm/docs/pretrain.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+../../../../llm/docs/pretrain.md

llm/alignment/rl/README.md

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ export FLAGS_cascade_attention_max_partition_size=2048

 python -u -m paddle.distributed.launch --devices "0,1,2,3" run_rl.py ../../config/qwen/reinforce_plus_plus_argument.yaml
 ```
+We provide a [wandb log](https://api.wandb.ai/links/ainlp66-netflix/injcw3ra) that is reproducible with the above script.

 ### Online monitoring
 The output directory configured in `grpo_argument.yaml` and `reinforce_plus_plus_argument.yaml` is `"logging_dir": "vdl_log"`; the training process can be viewed with the following command
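
The commit's headline change here is linking a reproducible wandb run for the REINFORCE++ script. For readers unfamiliar with Weights & Biases, the snippet below is a minimal, hypothetical sketch of how RL training metrics could be pushed to a dashboard like the one linked above; the project name, run name, and metric keys are illustrative and are not taken from PaddleNLP.

```python
# Minimal, hypothetical sketch of logging RL training metrics to Weights & Biases.
# Project/run names and metric keys are illustrative, not PaddleNLP's.
import wandb

wandb.init(project="reinforce-plus-plus-demo", name="qwen-rl-run")

for step in range(3):  # stand-in for the real training loop
    metrics = {
        "train/reward": 0.1 * step,        # placeholder values
        "train/kl_divergence": 0.01,
        "train/pg_loss": 1.0 - 0.05 * step,
    }
    wandb.log(metrics, step=step)

wandb.finish()
```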

llm/config/qwen/grpo_32b_argument.yaml

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ disable_tqdm: true # Whether to disable tqdm progress bar

 # RL args
 kl_coeff: 0.001 # KL coefficient for PPO and Reinforce++
-kl_loss_coeff: 0.001 # KL loss coefficient
+kl_loss_coeff: 0.000 # KL loss coefficient
 pg_loss_coeff: 1.0 # Policy gradient loss coefficient
 entropy_coeff: 0.0 # Entropy coefficient
 clip_range_ratio: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm)

llm/config/qwen/grpo_argument.yaml

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ disable_tqdm: true # Whether to disable tqdm progress bar

 # RL args
 kl_coeff: 0.001 # KL coefficient for PPO and Reinforce++
-kl_loss_coeff: 0.001 # KL loss coefficient
+kl_loss_coeff: 0.000 # KL loss coefficient
 pg_loss_coeff: 1.0 # Policy gradient loss coefficient
 entropy_coeff: 0.0 # Entropy coefficient
 clip_range_ratio: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm)
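
Both GRPO configs flip `kl_loss_coeff` from 0.001 to 0.000, i.e. the explicit KL penalty added to the training loss is disabled while the reward-side `kl_coeff` stays at 0.001. As a rough illustration of what such a coefficient usually controls, the sketch below combines per-sample loss terms with these weights; it is a generic composition under assumed semantics, not PaddleNLP's actual loss code.

```python
# Generic sketch (not PaddleNLP's implementation) of how coefficients like
# pg_loss_coeff, kl_loss_coeff, and entropy_coeff are commonly combined.
def combine_losses(pg_loss, kl_loss, entropy,
                   pg_loss_coeff=1.0, kl_loss_coeff=0.000, entropy_coeff=0.0):
    # With kl_loss_coeff = 0.000 the explicit KL penalty term vanishes;
    # any KL control then has to come from the reward-side kl_coeff.
    return pg_loss_coeff * pg_loss + kl_loss_coeff * kl_loss - entropy_coeff * entropy

# Example: the KL term contributes nothing when its coefficient is zero.
print(combine_losses(pg_loss=0.8, kl_loss=2.5, entropy=1.2))  # -> 0.8
```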

paddlenlp/rl/trainer/ppo_trainer.py

Lines changed: 3 additions & 4 deletions
@@ -1526,7 +1526,7 @@ def train(
         if self.args.rl_algorithm == "ppo":
             batch["reward_values"] = self.critic_trainer.compute_value(**batch)

-        # danamic sampling: filter generated samples by rewards, keep generating until valid samples are enough
+        # dynamic sampling: filter generated samples by rewards, keep generating until valid samples are enough
         if self.args.dynamic_sampling:
             local_valid_prompt = 0
             # combined_batch = combine_micro_batches_into_batch(micro_batches, pad_token_id=self.tokenizer.pad_token_id)
@@ -1601,7 +1601,7 @@ def train(
                     total_batch = defaultdict(list)
                     total_valid_prompt = 0
                     num_gen_batches = 0
-                    logger.info("Danymic sampling completed. \n")
+                    logger.info("Dynamic sampling completed. \n")

                 else:
                     if self.args.max_gen_batches > 0 and num_gen_batches > self.args.max_gen_batches:
@@ -1664,7 +1664,7 @@ def train(
                 paddle.device.cuda.empty_cache()

                 if self.args.rl_algorithm == "ppo":
-                    rl_info["train_value_loss"] = self.critic_trainer.update_critc(micro_batch)
+                    rl_info["train_value_loss"] = self.critic_trainer.update_critic(micro_batch)
                 if self.is_step_end():
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
@@ -1701,7 +1701,6 @@ def train(

                 if self.control.should_training_stop:
                     break
-                # TODO(guosheng): add epilogue of training
         logger.info("\nTraining completed. \n")
         if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             if args.local_rank != -1:
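
The comment fixed in the first hunk describes dynamic sampling: generated samples are filtered by their rewards, and generation continues until enough valid prompts have been collected (bounded by `max_gen_batches`, as the second hunk shows). The sketch below illustrates that filtering loop in isolation; the `generate_batch` helper and the "all rewards identical means invalid" rule are assumptions for illustration, not the trainer's actual code.

```python
# Standalone sketch of dynamic sampling as described in the fixed comment:
# keep generating batches, drop prompts whose reward signal is uninformative,
# and stop once enough valid prompts have been collected.
# `generate_batch` and the filtering rule are hypothetical, for illustration only.
import random

def generate_batch(batch_size=4, group_size=4):
    # Pretend each prompt gets `group_size` sampled responses with 0/1 rewards.
    return [[random.randint(0, 1) for _ in range(group_size)] for _ in range(batch_size)]

def dynamic_sampling(target_valid_prompts=8, max_gen_batches=16):
    valid_groups, num_gen_batches = [], 0
    while len(valid_groups) < target_valid_prompts:
        num_gen_batches += 1
        if num_gen_batches > max_gen_batches:
            raise RuntimeError("Exceeded max_gen_batches without enough valid prompts.")
        for rewards in generate_batch():
            # A prompt is "valid" only if its rewards are not all identical,
            # i.e. the sampled group actually carries a learning signal.
            if len(set(rewards)) > 1:
                valid_groups.append(rewards)
    return valid_groups[:target_valid_prompts]

print(len(dynamic_sampling()))  # -> 8
```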
