Skip to content

Commit

Permalink
initial script copied from the dpo trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Feb 11, 2025
1 parent 7fb481f commit ccfaf0b
Show file tree
Hide file tree
Showing 3 changed files with 1,201 additions and 0 deletions.
30 changes: 30 additions & 0 deletions examples/research_projects/diffusion_grpo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Diffusion Model Alignment Using GRPO


This directory provides LoRA implementations of Diffusion [GRPO](https://arxiv.org/abs/2402.03300) an RL based alignment method which is a variant of Proximal Policy Optimization (PPO) in the diffusion model setting.

## SDXL training command

```bash
accelerate launch train_diffusion_grpo_sdxl.py \
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
8 changes: 8 additions & 0 deletions examples/research_projects/diffusion_grpo/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft
wandb
Loading

0 comments on commit ccfaf0b

Please sign in to comment.