Adding grpo training #1233

Open
wants to merge 52 commits into base: main

Conversation

Goekdeniz-Guelmez
Contributor

No description provided.

@mark-lord

mark-lord commented Feb 2, 2025

Absolute HERO! Been trying to figure this out myself the past week but made pretty much no progress whatsoever, other than to make a script that fills up all the RAM on my Mac 🤣

Is there any way to run this yet? I assume not, since at the moment it's still marked as a draft and there isn't a lora_config.yaml like in the DPO example yet (not sure if one is needed)?

@Goekdeniz-Guelmez
Contributor Author

No, not yet. I still have to implement the dataset wrapper and some other things; I'll let you know when it's done.

@Guo-astro Guo-astro left a comment

Possibly need to use expanded_prompts and expanded_answers in both the reward and the loss computation (see the sketch below the file link).

llms/mlx_lm/tuner/grpo_trainer.py (outdated)
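
For context, a minimal sketch of what that expansion could look like (the names expanded_prompts, expanded_answers, and group_size follow the comment above; this is illustrative, not the PR's actual code):

# Hypothetical sketch: repeat each prompt/answer group_size times so that every
# sampled completion in a group lines up with its own prompt/answer when
# computing rewards and the loss.
def expand_for_group(prompts, answers, group_size):
    expanded_prompts = [p for p in prompts for _ in range(group_size)]
    expanded_answers = [a for a in answers for _ in range(group_size)]
    return expanded_prompts, expanded_answers

# e.g. with --group-size 2:
# expand_for_group(["q1", "q2"], ["a1", "a2"], 2)
# -> (["q1", "q1", "q2", "q2"], ["a1", "a1", "a2", "a2"])
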
@Goekdeniz-Guelmez
Contributor Author

python -m mlx_lm.lora \
    --model Qwen/Qwen2.5-0.5B \
    --train \
    --data /Users/gokdenizgulmez/Desktop/test_grpo \
    --iters 5 \
    --batch-size 1 \
    --num-layers 4 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
    --max-seq-length 128 \
    --grad-checkpoint \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2

Output

Loading pretrained model
Fetching 7 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 124936.71it/s]
Loading datasets
Training
Trainable parameters: 0.109% (0.541M/494.033M)
Starting GRPO training with 5 reward functions..., iters: 5
[WARNING] Some prompts are longer than 128 tokens. Long prompts will be truncated.
Iter 1: Val loss 0.00000140, Val total_rewards_mean -0.359, Val total_rewards_std 0.010, Val grouped_rewards_mean -0.359, Val grouped_rewards_std 0.010, Val kl 0.000, Val reward_func_0_mean 0.000, Val reward_func_0_std 0.000, Val reward_func_1_mean 0.000, Val reward_func_1_std 0.000, Val reward_func_2_mean 0.000, Val reward_func_2_std 0.000, Val reward_func_3_mean 0.000, Val reward_func_3_std 0.000, Val reward_func_4_mean -1.794, Val reward_func_4_std 0.051, Val took 8.385s

But after that, my 32 GB of RAM gets fully used. I tried adding some memory optimisations, but the usage is still too high.
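
One thing that sometimes helps on Apple silicon (an aside, not part of this PR, and assuming an MLX version that exposes these mx.metal helpers) is capping the Metal buffer cache and clearing it between iterations:

# Sketch only: cap MLX's Metal buffer cache and release cached buffers
# periodically inside the training loop.
import mlx.core as mx

mx.metal.set_cache_limit(2 * 1024**3)  # keep at most ~2 GB of cached buffers

# ... after each optimizer update:
mx.metal.clear_cache()  # free unused cached buffers back to the system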

@Goekdeniz-Guelmez
Contributor Author

Iter 1: Val loss -0.00000057, Val total_rewards_mean -0.387, Val total_rewards_std 0.026, Val grouped_rewards_mean -0.387, Val grouped_rewards_std 0.026, Val kl 0.000, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean -1.937, Val r1_count_xml_std 0.128, Val took 8.314s

Still uses too much memory.

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 3, 2025

So I tried using TRL and the same amount of RAM was used, so it's not an error on my side.

@mark-lord

🚀

Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌

@Goekdeniz-Guelmez
Contributor Author

Will do that tomorrow 🤝

@Guo-astro

🚀

Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌

I created a quick one only for testing the code

https://huggingface.co/datasets/Goastro/mlx-grpo-dataset

@Goekdeniz-Guelmez
Contributor Author

python -m mlx_lm.lora \
    --model Qwen/Qwen2.5-0.5B \
    --train \
    --data /Users/gokdenizgulmez/Desktop/test_grpo \
    --iters 5 \
    --batch-size 1 \
    --num-layers 8 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
    --max-seq-length 255 \
    --grad-checkpoint \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --max-completion-length 6

Output:

Loading pretrained model
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 72853.92it/s]
Loading datasets
Training
Trainable parameters: 0.109% (0.541M/494.033M)
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10955.27it/s]
Starting GRPO training with 5 reward functions..., iters: 5
Iter 1: Val loss 0.00000000, Val total_rewards_mean -0.354, Val total_rewards_std 0.012, Val grouped_rewards_mean -0.354, Val grouped_rewards_std 0.012, Val kl 0.000, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean -1.769, Val r1_count_xml_std 0.060, Val took 26.298s
Iter 1: Train loss -0.00001353, Total rewards mean -0.306, Total rewards std 0.001, Grouped rewards mean -0.306, Grouped rewards std 0.001, KL 0.000, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -1.532, r1_count_xml std 0.005, Learning Rate 1.000e-05, It/sec 0.079, Tokens/sec 25.072, Peak mem 7.254 GB
Iter 2: Train loss 0.00055540, Total rewards mean -0.572, Total rewards std 0.001, Grouped rewards mean -0.572, Grouped rewards std 0.001, KL 0.006, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -2.861, r1_count_xml std 0.005, Learning Rate 1.000e-05, It/sec 0.121, Tokens/sec 36.164, Peak mem 7.254 GB
Iter 3: Train loss 0.00070858, Total rewards mean -0.842, Total rewards std 0.003, Grouped rewards mean -0.842, Grouped rewards std 0.003, KL 0.013, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -4.210, r1_count_xml std 0.013, Learning Rate 1.000e-05, It/sec 0.110, Tokens/sec 31.790, Peak mem 7.254 GB
Iter 4: Train loss 0.00070563, Total rewards mean -1.161, Total rewards std 0.005, Grouped rewards mean -1.161, Grouped rewards std 0.005, KL 0.020, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -5.806, r1_count_xml std 0.024, Learning Rate 1.000e-05, It/sec 0.105, Tokens/sec 36.961, Peak mem 7.899 GB
Iter 5: Val loss 0.00057772, Val total_rewards_mean -0.345, Val total_rewards_std 0.005, Val grouped_rewards_mean -0.345, Val grouped_rewards_std 0.005, Val kl 0.006, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean -1.726, Val r1_count_xml_std 0.025, Val took 22.624s
Iter 5: Train loss 0.00059050, Total rewards mean -1.399, Total rewards std 0.006, Grouped rewards mean -1.399, Grouped rewards std 0.006, KL 0.026, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -6.994, r1_count_xml std 0.029, Learning Rate 1.000e-05, It/sec 0.156, Tokens/sec 39.539, Peak mem 7.899 GB
Saved final weights to /Users/gokdenizgulmez/Desktop/test-grpo-full/adapters.safetensors.
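
For anyone reading along, the grouped_rewards_mean/std values in these logs come from GRPO's group-relative advantage. Roughly (an illustrative sketch; the names and shapes are assumptions, not the PR's exact code):

import mlx.core as mx

def group_advantages(rewards: mx.array, group_size: int, eps: float = 1e-4):
    # rewards holds one scalar per sampled completion; completions generated
    # from the same prompt are adjacent, so reshape to (num_prompts, group_size)
    grouped = rewards.reshape(-1, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.sqrt(mx.var(grouped, axis=1, keepdims=True))
    # advantage of each completion relative to its own group
    return ((grouped - mean) / (std + eps)).reshape(-1)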

@mark-lord

🥳🥳🥳

Working on my machine too! Not to mention it's plug-and-play with QLoRA as well, which I don't think TRL even has 😁 And already used it to get an 'aha' moment out of Phi-14b and do some knowledge injection 🚀
[Screenshot attached: 2025-02-04 at 02:10:40]

@anson1788

By the way, any chance you can share your dataset? I want to have a look.

@Guo-astro Guo-astro left a comment

Using the GPU could possibly save some main memory (a rough sketch below the file link).

llms/mlx_lm/tuner/grpo_trainer.py
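
If it helps, the idea would be something like the following (a sketch, not the PR's code):

import mlx.core as mx

# Run MLX ops on the GPU backend explicitly. On Apple silicon memory is
# unified, so this mainly controls which backend executes the kernels.
mx.set_default_device(mx.gpu)
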
@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 11, 2025

Model: Qwen/Qwen2.5-1.5B
Dataset: Goastro/mlx-grpo-dataset
Prompt: Ann is cutting fabric to make curtains. She cuts a 4 foot by 6 foot rectangle for the living room, and a 2 foot by 4 foot rectangle for the bedroom. If the bolt of fabric is 16 feet by 12 feet, how much fabric is left in square feet?

Before:

Assistant: <think>Let's break down the problem step by step.</think>
<answer>First, we need to calculate the total area of the fabric used for the curtains.</answer>
<think>Area of the living room curtain = 4 feet x 6 feet = 24 square feet</think>
<think>Area of the bedroom curtain = 2 feet x 4 feet = 8 square feet</think>
<answer>Total area used = 24 square feet + 8 square feet = 32 square feet</answer>
<think>Next, we need to calculate the total area of the bolt of fabric.</answer>
<think>Area of the bolt of fabric = 16 feet x 12 feet = 192 square feet</think>
<answer>Total area used = 32 square feet</answer>
<think>Finally, we can calculate the remaining fabric by subtracting the total area used from the total area of the bolt of fabric.</answer>
<answer>Remaining fabric = 192 square feet - 32 square feet = 160 square feet</answer>

After:

Assistant: <think>First, we need to calculate the total area of the fabric used for the curtains. The living room curtain is 4 feet by 6 feet, so its area is 4 * 6 = 24 square feet. The bedroom curtain is 2 feet by 4 feet, so its area is 2 * 4 = 8 square feet. The total area of the fabric used is 24 + 8 = 32 square feet. The bolt of fabric is 16 feet by 12 feet, so its total area is 16 * 12 = 192 square feet. Therefore, the amount of fabric left is 192 - 32 = 160 square feet.</think><answer>160</answer>
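
The before/after difference above is exactly what the format-style rewards in the logs (e.g. r1_soft_format_reward_func) are meant to capture. A minimal sketch of such a reward, purely for illustration (the PR's own reward functions may differ):

import re

# Reward 1.0 if the completion is a single <think>...</think><answer>...</answer>
# block, else 0.0.
_PATTERN = re.compile(r"^<think>.*?</think>\s*<answer>.*?</answer>\s*$", re.DOTALL)

def soft_format_reward(completions, **kwargs):
    return [1.0 if _PATTERN.match(c) else 0.0 for c in completions]

# soft_format_reward(["<think>24 + 8 = 32</think><answer>160</answer>"])  # -> [1.0]
# soft_format_reward(["just some text"])                                  # -> [0.0]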

@Guo-astro Guo-astro left a comment

Easier to read code 1/n

llms/mlx_lm/tuner/grpo_trainer.py (outdated)

@Guo-astro Guo-astro left a comment

Mainly adding 3 types: RewardFunction, GRPOExample, GRPOBatch (a rough sketch of what these could look like follows the file list below).

llms/mlx_lm/tuner/datasets.py
llms/mlx_lm/tuner/datasets.py
llms/mlx_lm/tuner/grpo_trainer.py
llms/mlx_lm/tuner/grpo_trainer.py
llms/mlx_lm/tuner/grpo_trainer.py
llms/mlx_lm/tuner/grpo_trainer.py
llms/mlx_lm/tuner/grpo_trainer.py
llms/mlx_lm/tuner/grpo_trainer.py (outdated)
llms/mlx_lm/tuner/grpo_trainer.py (outdated)
llms/mlx_lm/tuner/grpo_trainer.py (outdated)
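
To make the "3 types" suggestion concrete, here is a rough sketch of what they could look like (field names are guesses for illustration, not the PR's definitions; slots=True assumes Python >= 3.10 and is one way to keep per-instance overhead low, which matters for the slowdown discussed below):

from dataclasses import dataclass
from typing import Callable, List

# A reward function maps prompts, completions, and answers to one score per completion.
RewardFunction = Callable[[List[str], List[str], List[str]], List[float]]

@dataclass(slots=True)
class GRPOExample:
    prompt_tokens: List[int]
    answer_text: str

@dataclass(slots=True)
class GRPOBatch:
    prompts: List[List[int]]  # one token list per example
    answers: List[str]
    group_size: int
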
@Goekdeniz-Guelmez
Contributor Author

Thanks @Guo-astro, however this made the computation skyrocket: validation went from 70-80 s to 130-150 s, and training slowed by about the same, probably due to copying data multiple times through class instantiation.

@Guo-astro

Guo-astro commented Feb 12, 2025

True. Then I think we need to use those as sparingly as possible. Python manages all the class-instance memory, so it could be really slow 😅
So you can just ignore those comments and go forward 🔥

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 12, 2025

I think I'll use a hybrid approach with your suggestions, because they make it more stable, maintainable, and easier to debug and test. Thanks for your help!!!!!!

@madroidmaq
Contributor

I cannot load the existing dataset on HF; the following error occurs when using Goastro/mlx-grpo-dataset for testing.

input:

python -m mlx_lm.lora \
    --model Qwen/Qwen2.5-0.5B \
    --train \
    --data Goastro/mlx-grpo-dataset \
    --iters 5 \
    --batch-size 1 \
    --num-layers 8 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path ~/Desktop/test-grpo-full \
    --max-seq-length 255 \
    --grad-checkpoint \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --max-completion-length 6

output:

Loading pretrained model
Fetching 7 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 59313.39it/s]
Loading datasets
Loading Hugging Face dataset Goastro/mlx-grpo-dataset.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/madroid/Desktop/mlx-examples/llms/mlx_lm/lora.py", line 432, in <module>
    main()
  File "/Users/madroid/Desktop/mlx-examples/llms/mlx_lm/lora.py", line 428, in main
    run(types.SimpleNamespace(**args))
  File "/Users/madroid/Desktop/mlx-examples/llms/mlx_lm/lora.py", line 391, in run
    train_set, valid_set, test_set = load_dataset(args, tokenizer)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/madroid/Desktop/mlx-examples/llms/mlx_lm/tuner/datasets.py", line 313, in load_dataset
    train, valid, test = load_hf_dataset(args.data, tokenizer, args)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: load_hf_dataset() missing 1 required positional argument: 'config'

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 12, 2025

Should be fixed now. I also suggest using the --use-prompt argument when it's a base model and the dataset has only prompts and answers, or use my Goekdeniz-Guelmez/GRPO-MLX-Dataset; it's GSM8K but prompted correctly as in the DeepSeek R1 paper, so you don't need --use-prompt. (A rough sketch of the expected prompt/answer fields is below.)
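
For reference, my understanding (an assumption; the linked datasets are authoritative) is that each line of the training JSONL only needs a prompt and an answer field, roughly like this hypothetical example:

import json

examples = [
    {"prompt": "What is 7 * 8?", "answer": "56"},
    {"prompt": "A bolt of fabric is 16 feet by 12 feet. What is its area in square feet?", "answer": "192"},
]

with open("train.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")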

@Goekdeniz-Guelmez
Contributor Author

@mark-lord you should be able to run it now!! If you want to use a base model, you can use the Goastro/mlx-grpo-dataset but you need to include the --use-prompt argument in the call, or you can use my Goekdeniz-Guelmez/GRPO-MLX-Dataset, which has the prompting already in it.

@Goekdeniz-Guelmez Goekdeniz-Guelmez marked this pull request as ready for review February 13, 2025 20:46
@Goekdeniz-Guelmez
Contributor Author

Thanks!!

@mark-lord

@mark-lord you should be able to run it now!! If you want to use a base model, you can use the Goastro/mlx-grpo-dataset but you need to include the --use-prompt argument in the call, or you can use my Goekdeniz-Guelmez/GRPO-MLX-Dataset, which has the prompting already in it.

🔝🔝🔝 Will try and get it running on my machine ASAP!
