
Add Direct Preference Optimization (DPO) method #1279

Open

anupamme wants to merge 1 commit into `main`

Conversation

@anupamme commented Feb 12, 2025

Fixes #513

Implement the Direct Preference Optimization (DPO) method as a Reinforcement Learning from Human Feedback (RLHF) example.

* **Add DPO Functions**: Add `get_batched_logps` and `dpo_loss` to `llms/mlx_lm/utils.py` for the DPO implementation (a minimal sketch of both functions follows after this list).
* **Update Training Logic**: Update `llms/mlx_lm/tuner/trainer.py` with DPO-specific training logic, including a `dpo_loss` path and a check in the training loop that selects the DPO loss (see the dispatch sketch below).
* **Add Configuration Options**: Add DPO configuration options to `llms/mlx_lm/examples/lora_config.yaml`.
* **Update Documentation**: Update `llms/mlx_lm/README.md` with instructions for using DPO.
* **Add Unit Tests**: Add `llms/tests/test_dpo.py` with unit tests for `get_batched_logps`, `dpo_loss`, and the DPO-specific training logic.
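
For context, here is a minimal sketch of what the two helpers could look like in MLX. This is not the exact code in the PR: the signatures, the convention of packing chosen and rejected sequences into the two halves of one batch, and the default `beta` are assumptions for illustration.

```python
import mlx.core as mx


def get_batched_logps(model, inputs, targets, mask):
    """Sum per-token log-probs of `targets` under `model`, ignoring masked positions.

    Assumes the first half of the batch holds the chosen responses and the
    second half the rejected ones.
    """
    logits = model(inputs)  # (batch, seq_len, vocab)
    # Log-softmax over the vocabulary, then pick the log-prob of each target token.
    logps = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    per_token = mx.take_along_axis(logps, mx.expand_dims(targets, -1), axis=-1)
    per_token = mx.squeeze(per_token, axis=-1)
    seq_logps = (per_token * mask).sum(axis=-1)  # (batch,)
    half = seq_logps.shape[0] // 2
    return seq_logps[:half], seq_logps[half:]  # chosen, rejected


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss: -log sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))."""
    margins = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    # log(sigmoid(x)) = -logaddexp(0, -x), which is numerically stable.
    return mx.logaddexp(0.0, -beta * margins).mean()
```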

For more details, open the Copilot Workspace session.
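
The training-loop integration could then be a dispatch along these lines. This is a hypothetical illustration only: the option names (`loss_type`, `beta`), the batch keys, and the `default_loss` fallback are placeholders, not necessarily the names used in `trainer.py` or `lora_config.yaml`.

```python
# Hypothetical dispatch in the training loop; all names are illustrative.
def compute_loss(model, ref_model, batch, args):
    inputs, targets, mask = batch["inputs"], batch["targets"], batch["mask"]
    if getattr(args, "loss_type", "sft") == "dpo":
        # Log-probs under the trainable policy and the frozen reference model.
        pi_chosen, pi_rejected = get_batched_logps(model, inputs, targets, mask)
        ref_chosen, ref_rejected = get_batched_logps(ref_model, inputs, targets, mask)
        return dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=args.beta)
    # Otherwise fall back to the existing supervised fine-tuning loss (placeholder name).
    return default_loss(model, inputs, targets, mask)
```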

Successfully merging this pull request may close issue #513: Reinforcement Learning from Human Feedback (RLHF) examples: Direct Preference Optimization (DPO).