


Direct Preference Optimization from scratch in PyTorch

Report Bug · Request Feature

About The Project

This project is an implementation of Direct Preference Optimization, an alternative to RLHF for aligning Large Language Models (LLMs) with human preferences. The algorithm is described in the research paper Direct Preference Optimization: Your Language Model is Secretly a Reward Model.

Direct Preference Optimization (DPO) is a promising and efficient technique for fine-tuning LLMs to align with human preferences. Compared to traditional Reinforcement Learning from Human Feedback (RLHF), DPO eliminates the need for a separate reward model and simplifies the training process, leading to better stability and computational efficiency.

The key insight of Direct Preference Optimization is to replace the complex reward modeling process in RLHF with a simple loss function that optimizes for human preferences directly, in closed form. Given a preference dataset, it increases the log probability of the tokens in the human-preferred responses and decreases the log probability of the tokens in the human-dispreferred responses, so that the model effectively learns an implicit reward function that is directly optimized for human preferences. This makes the process much simpler and more efficient than RLHF: it needs no separate reward model, and it is more stable because it does not rely on reinforcement learning algorithms such as PPO for fine-tuning.
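In practice, "the log probability of a response" is just the sum of the token-level log probabilities of that response under the model, conditioned on the prompt. A minimal PyTorch sketch of that step is shown below; the function name, the Hugging-Face-style `model(input_ids).logits` call, and the response mask are illustrative assumptions rather than code from this repository.

```python
import torch
import torch.nn.functional as F

def sequence_log_prob(model, input_ids, response_mask):
    """Sum of log-probabilities of the response tokens, given the prompt.

    input_ids:     (batch, seq_len) prompt + response token ids
    response_mask: (batch, seq_len) 1 for response tokens, 0 for prompt/padding
    (Names and masking convention are illustrative assumptions.)
    """
    logits = model(input_ids).logits                      # (batch, seq_len, vocab), assumes an HF-style causal LM
    # Token t is predicted from positions < t: shift logits left, labels right.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # (batch, seq_len - 1, vocab)
    labels = input_ids[:, 1:]                             # (batch, seq_len - 1)
    token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    # Keep only the response tokens, then sum to get log pi(y | x) per sequence.
    return (token_log_probs * response_mask[:, 1:]).sum(dim=-1)
```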

The DPO loss function is defined as follows:

$$ L_\text{DPO}(\pi_{\theta}; \pi_\text{ref}) = -E_{(x, y_w, y_l)\sim D}\left[\log \sigma \left( \beta \log \frac{\pi_{\theta}(y_w\mid x)}{\pi_\text{ref}(y_w\mid x)} \thinspace {- \beta \log \frac{\pi_{\theta}(y_l\mid x)}{\pi_\text{ref}(y_l\mid x)}}\right)\right] $$

where:

  • $\pi_{\theta}$ is the language model we want to fine-tune
  • $\pi_\text{ref}$ is a reference model, usually a frozen version of the original pre-trained language model
  • $D$ is the dataset of preferences
  • $x$ is a sample prompt from the dataset $D$
  • $y_w$ is the human-preferred response to the prompt $x$
  • $y_l$ is the human-dispreferred response to the prompt $x$
  • $\beta$ is a hyperparameter that controls the amount of divergence from the reference model $\pi_\text{ref}$

The DPO loss can be broken down into two main terms. The first term represents the log probability of the human-preferred response $y_w$ under the model $\pi_{\theta}$, relative to the reference model $\pi_{\text{ref}}$. The division by $\pi_{\text{ref}}$ acts as a regularizer, ensuring that fine-tuning does not cause the model to deviate excessively from its original training. Maximizing this term increases the likelihood that $\pi_{\theta}$ generates responses similar to $y_w$ for prompts like $x$, reinforcing the human preference patterns. Conversely, the second term, with its negative sign, minimizes the relative log probability of the human-dispreferred response $y_l$, reducing the model's tendency to generate $y_l$-like responses.
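As a quick illustration of how the two terms interact (the numbers below are made up), a log-ratio margin of zero gives a loss of $\log 2 \approx 0.693$, and the loss grows as the dispreferred response becomes relatively more likely than the preferred one:

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Hypothetical per-sequence log-ratios log(pi_theta / pi_ref), for illustration only.
chosen_logratio = torch.tensor(-0.2)   # preferred response: slightly less likely than under the reference
rejected_logratio = torch.tensor(0.3)  # dispreferred response: more likely than under the reference

margin = beta * (chosen_logratio - rejected_logratio)  # -0.05
loss = -F.logsigmoid(margin)
print(loss.item())  # ~0.718, above log(2) ~ 0.693, so the gradient pushes the margin up
```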

The hyperparameter $\beta$, typically set between 0.1 and 0.5, controls how far the model may diverge from the reference model $\pi_\text{ref}$, allowing controlled adjustments to the model's outputs while preventing significant deviations from the reference model's behavior. The whole expression is then averaged over the dataset $D$ (or a batch sampled from it), giving the final DPO loss that we minimize with gradient descent to fine-tune the language model.
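Putting the pieces together, a minimal PyTorch sketch of this loss is shown below. It assumes the per-sequence log probabilities (as in the earlier sketch) have already been computed for both the policy and the frozen reference model; the function and argument names are illustrative and not taken from this repository.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """DPO loss from the formula above, averaged over the batch.

    Each argument is a (batch,) tensor of per-sequence log probabilities
    log pi(y | x); the reference log-probs should be computed under torch.no_grad().
    """
    chosen_logratio = policy_chosen_logps - ref_chosen_logps          # log pi_theta(y_w|x) / pi_ref(y_w|x)
    rejected_logratio = policy_rejected_logps - ref_rejected_logps    # log pi_theta(y_l|x) / pi_ref(y_l|x)
    # -log sigmoid(beta * (chosen - rejected)), averaged over the batch.
    loss = -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()
    # Implicit rewards, handy for logging how strongly preferences separate.
    chosen_reward = beta * chosen_logratio.detach()
    rejected_reward = beta * rejected_logratio.detach()
    return loss, chosen_reward, rejected_reward
```

A standard optimizer step (`loss.backward()` followed by `optimizer.step()`) then fine-tunes $\pi_{\theta}$ while $\pi_\text{ref}$ stays frozen.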

For a detailed explanation, check out my blog post Unveiling the Hidden Reward System in Language Models: A Dive into DPO.
