Adding ORPO training #1210

Open

wants to merge 35 commits into ml-explore:main from adding-orpo-training

Commits (35)
1ff7888
initial commit
Goekdeniz-Guelmez Jan 18, 2025
582f979
fixing reference model loading and freezing
Goekdeniz-Guelmez Jan 18, 2025
1b4e196
update LORA.md
Goekdeniz-Guelmez Jan 18, 2025
06a9f5d
update lora_config.yaml
Goekdeniz-Guelmez Jan 18, 2025
040f7c3
update ACKNOWLEDGMENTS.md
Goekdeniz-Guelmez Jan 18, 2025
51fd621
nits
Goekdeniz-Guelmez Jan 19, 2025
a9b7609
initial commit
Goekdeniz-Guelmez Jan 19, 2025
7d279b5
remerge with dpo
Goekdeniz-Guelmez Jan 19, 2025
fa80d08
finish
Goekdeniz-Guelmez Jan 19, 2025
9ede9db
nits
Goekdeniz-Guelmez Jan 19, 2025
424cb85
nits
Goekdeniz-Guelmez Jan 19, 2025
2a5b315
update ACKNOWLEDGMENTS.md
Goekdeniz-Guelmez Jan 19, 2025
ea0d11c
update
Goekdeniz-Guelmez Jan 19, 2025
363bde6
fixes
Goekdeniz-Guelmez Jan 19, 2025
61cd253
nits
Goekdeniz-Guelmez Jan 19, 2025
4098c3b
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Jan 22, 2025
0bb0011
niits
Goekdeniz-Guelmez Jan 22, 2025
e368829
removing dpo and fixing some stuff for orpo
Goekdeniz-Guelmez Jan 24, 2025
09ed837
updates
Goekdeniz-Guelmez Jan 24, 2025
d8e7834
Removed rejected_rewards handling, Updated batch unpacking to match i…
Goekdeniz-Guelmez Jan 25, 2025
2f2ddd4
clean up
Goekdeniz-Guelmez Jan 26, 2025
294d189
Merge branch 'main' into adding-orpo-training
Goekdeniz-Guelmez Jan 26, 2025
649d3f8
fix ACKNOWLEDGMENTS
Goekdeniz-Guelmez Jan 26, 2025
ceccb4c
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Jan 29, 2025
541677a
cleaning up
Goekdeniz-Guelmez Jan 31, 2025
2c96da5
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Feb 3, 2025
c33c245
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Feb 4, 2025
1beefd5
add create_dataset
Goekdeniz-Guelmez Feb 4, 2025
43940ec
fix Test
Goekdeniz-Guelmez Feb 4, 2025
5671266
nice metric printing in testing
Goekdeniz-Guelmez Feb 4, 2025
594b435
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Feb 6, 2025
c967204
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Feb 6, 2025
575ece6
Merge branch 'main' into adding-orpo-training
Goekdeniz-Guelmez Feb 10, 2025
80c64da
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Feb 12, 2025
348f728
Merge branch 'ml-explore:main' into adding-orpo-training
Goekdeniz-Guelmez Feb 18, 2025
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `full-fine-tuning`, and `Odds Ratio Preference Optimization (ORPO)` training.
55 changes: 54 additions & 1 deletion llms/mlx_lm/LORA.md
@@ -12,12 +12,14 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Gemma
- OLMo
- MiniCPM
- Mamba
- InternLM2

## Contents

- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [ORPO Training](#ORPO-Training)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -82,7 +84,58 @@ The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.

### ORPO Training

Odds Ratio Preference Optimization (ORPO) fine-tunes a model on human preference data: each training example pairs a prompt with a preferred (chosen) and a less preferred (rejected) response. Usage:

```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--training-mode orpo \
--data <path_to_data> \
--beta 0.1
```
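
These options can also be set in a YAML config file passed with `-c`/`--config`. Below is a minimal sketch that reuses keys from `examples/lora_config.yaml`; the model and data paths are placeholders:

```yaml
# orpo_config.yaml -- a minimal sketch, not a tested configuration
model: "path/to/model"
train: true
fine_tune_type: lora
training_mode: orpo
beta: 0.1
data: "/path/to/preference/data"
```

Run it with `mlx_lm.lora -c orpo_config.yaml`.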

Parameters:

- `--beta`: Temperature for the logistic (odds-ratio) term in the ORPO loss (default: 0.1); see the sketch below for where it enters the objective.
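
As a rough illustration of where `beta` enters the objective, here is a minimal sketch of the ORPO loss in `mlx`. It is not the trainer's actual implementation, and the argument names (`chosen_logps`, `rejected_logps`, `chosen_nll`) are assumptions for this example:

```python
# Minimal sketch of the ORPO objective -- not the trainer's exact code.
import mlx.core as mx


def orpo_loss(chosen_logps, rejected_logps, chosen_nll, beta=0.1):
    # chosen_logps / rejected_logps: average per-token log-probabilities of the
    # chosen and rejected completions; chosen_nll: supervised NLL of the chosen one.
    # The log-odds of a completion is log p - log(1 - p), from average log-probs.
    log_odds = (chosen_logps - rejected_logps) - (
        mx.log1p(-mx.exp(chosen_logps)) - mx.log1p(-mx.exp(rejected_logps))
    )
    # Odds-ratio term: -log sigmoid(log_odds), written stably via logaddexp.
    ratio_loss = mx.logaddexp(0.0, -log_odds)
    # ORPO combines the supervised loss on the chosen response with the
    # beta-weighted odds-ratio penalty against the rejected one.
    return mx.mean(chosen_nll + beta * ratio_loss)
```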

Data format (JSONL):

```jsonl
# Basic format with string responses
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}

# With custom preference score
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "preference_score": 8.0}

# With system message
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "system": "System instruction"}

# With full conversation objects
{
"prompt": "User prompt",
"chosen": {
"messages": [
{"role": "system", "content": "System instruction"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant response"}
]
},
"rejected": {
"messages": [
{"role": "system", "content": "System instruction"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant response"}
]
}
}
```

If no explicit reward is provided via `preference_score`, the trainer assigns binary rewards (1.0 for the chosen response and 0.0 for the rejected one).
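
For reference, a small Python sketch that writes records in the format above; the file name and example strings are placeholders, not part of the PR:

```python
# Sketch: write a tiny preference dataset in the JSONL format described above.
import json

records = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France is a country in Europe.",
    },
    {
        "prompt": "Summarize photosynthesis in one sentence.",
        "chosen": "Plants turn light, water, and CO2 into glucose and oxygen.",
        "rejected": "Photosynthesis is a thing plants do.",
        "preference_score": 8.0,  # optional explicit reward
    },
]

with open("train.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")  # one JSON object per line
```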

### Evaluate

13 changes: 13 additions & 0 deletions llms/mlx_lm/examples/lora_config.yaml
@@ -7,6 +7,19 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora

# The training mode: "normal", "dpo", or "orpo"
training_mode: normal

# Options for preference training ("dpo" or "orpo"):
# beta: 0.1
# The DPO loss type: "sigmoid", "hinge", "ipo", or "dpop"
# dpo_loss_type: "sigmoid"
# is_reference_free: False
# delta: 50.0
# If reference_model_path is not given, the model being trained is also used as the reference
# reference_model_path: "mlx_model"
# train_bias_only: False

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

142 changes: 104 additions & 38 deletions llms/mlx_lm/lora.py
@@ -15,6 +15,7 @@
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -43,6 +44,7 @@
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"training_mode": "normal",
"data": "data/",
"seed": 0,
"num_layers": 16,
@@ -62,6 +64,11 @@
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"beta": 0.1,
"dpo_loss_type": "sigmoid",
"delta": 50.0,
"reference_model_path": None,
"reward_scaling": 1.0,
}


@@ -102,6 +109,12 @@ def build_parser():
default=False,
)

parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "dpo", "orpo"],
help="Training mode: normal, DPO or ORPO.",
)
parser.add_argument(
"--num-layers",
type=int,
@@ -143,7 +156,7 @@ def build_parser():
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
help="Evaluate on the test set after training.",
default=None,
)
parser.add_argument(
@@ -160,15 +173,29 @@
"-c",
"--config",
type=str,
help="A YAML configuration file with the training options",
help="A YAML configuration file with the training options.",
)
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument("--seed", type=int, help="The PRNG seed.")

# ORPO args
parser.add_argument(
"--beta",
type=float,
help="Temperature parameter for ORPO training.",
default=0.1
)
parser.add_argument(
"--reward-scaling",
type=float,
help="Reward scaling factor for ORPO training, not implemented.",
default=1.0
)
return parser


@@ -208,53 +235,92 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)

model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)

# Train model based on training mode
if args.training_mode == "orpo":
training_args = ORPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
reward_scaling=args.reward_scaling
)

train_orpo(
model=model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)

# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
train(
model=model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()

test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
if args.training_mode == "orpo":
test_loss, test_rewards, _, test_metrics = evaluate_orpo(
model=model,
dataset=test_set,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")

test_ppl = math.exp(test_loss)
print("ORPO Test Metrics:")
for metric_name, metric_value in test_metrics.items():
print(f" {metric_name}: {float(metric_value):.3f}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
Expand All @@ -272,7 +338,7 @@ def run(args, training_callback: TrainingCallback = None):
load_adapters(model, args.adapter_path)

elif args.train:
print("Training")
print(f"Training in {args.training_mode} mode")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")
@@ -305,4 +371,4 @@ def main():


if __name__ == "__main__":
main()