
[Feature] (Willing to PR) Proposal: Drop-in fast replacement of PreTrainedModel.generate #2569

Open
2 tasks done
fzyzcjy opened this issue Dec 24, 2024 · 6 comments
Labels
collaboration enhancement New feature or request feature high priority RLHF Using SGLang for post training

Comments

@fzyzcjy

fzyzcjy commented Dec 24, 2024

Checklist

Motivation

Hi, thanks for the library! Currently, a lot of code uses model.generate(), such as TRL's PPOTrainer. If we can make a drop-in replacement for it using SGLang, then everyone can very easily speed up their generation-related code. Examples include TRL's PPOTrainer and OpenRLHF's train_ppo.py (not train_ppo_ray.py, which is more for distributed training). IMHO there are many places where this can be useful; many online RL algorithms can benefit from it.

As for when to update the SGLang weights from the HF weights, the most naive solution may be to update the weights every time generate is called. This may not be a big problem, because we can configure the PPO batch size to be so large that model.generate is only called once.
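The naive "sync on every generate call" policy can be sketched as follows. This is a minimal illustration under stated assumptions, not SGLang's API: `SyncOnGenerate`, `ToyEngine`, `load_weights`, and the `{"version": ...}` weight dict are all hypothetical stand-ins for an HF model's weights and an inference engine.

```python
from typing import Any, Callable


class SyncOnGenerate:
    """Sketch of the naive policy: push trainer weights into the engine before every generate()."""

    def __init__(self, get_weights: Callable[[], dict], engine: Any) -> None:
        self.get_weights = get_weights  # hypothetical: returns {name: tensor} from the HF model
        self.engine = engine            # hypothetical engine exposing load_weights() and generate()
        self.num_syncs = 0

    def generate(self, prompts: list[str]) -> list[str]:
        # Always re-sync first, so generation never runs on stale weights.
        self.engine.load_weights(self.get_weights())
        self.num_syncs += 1
        return self.engine.generate(prompts)


class ToyEngine:
    """Hypothetical stand-in for an inference engine; tags outputs with the weight version."""

    def __init__(self) -> None:
        self.weights: dict = {}

    def load_weights(self, weights: dict) -> None:
        self.weights = dict(weights)  # the engine keeps its own copy

    def generate(self, prompts: list[str]) -> list[str]:
        return [f"{p}|v{self.weights['version']}" for p in prompts]


policy_weights = {"version": 0}
model = SyncOnGenerate(lambda: policy_weights, ToyEngine())
print(model.generate(["a", "b"]))  # ['a|v0', 'b|v0']
policy_weights["version"] = 1      # pretend an optimizer step updated the policy
print(model.generate(["c"]))       # ['c|v1']
```

The cost is one weight sync per generate call; if the batch is large enough that generate runs once per policy update, this amounts to syncing once per update.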

Related: #2542. With that, we can reduce the memory footprint outside generate.

Related resources

No response

@zhaochenyang20
Collaborator

> For example, TRL's PPOTrainer, OpenRLHF's train_ppo.py (not the train_ppo_ray.py, which is more for distributed training). IMHO there are many places this can be useful - many online RL algorithms can benefit from this.

Thanks for pointing this out! For your information, we may not integrate SGLang into TRL ourselves, since TRL is more or less out of date. But for OpenRLHF, yes, we will definitely do this, and I am working on it.

For OpenRLHF train_ppo, which is not distributed: we have an engine API, update_weights_from_distributed. I am not sure whether the current SGLang API can already handle this case, or whether we should add a similar API, e.g. one that takes the weights and their names directly rather than using torch.distributed to broadcast them.

For train_ppo_ray, we already have the update_weights_from_distributed API now. I will integrate it soon.

@zhaochenyang20
Collaborator

> As for when to update SGLang weight from HF weight, the most naive solution may be that we update weights every time generate is called. This may not be a big problem because we can configure the PPO batch size to be so huge that model.generate is only called once.

Also, I am not quite sure why we need to discuss this. In my experience, every time the policy model gets updated, we should also update the inference engine's weights. It's not related to the PPO batch size but only related to the training epochs?

@zhaochenyang20
Collaborator

zhaochenyang20 commented Dec 24, 2024

Also, regarding your title, “Drop-in fast replacement of PreTrainedModel.generate”: do you mean changing the inference engine in these post-training frameworks from huggingface/vllm to sglang? I am working on this and making some progress. If you really have time and are willing to contribute, we would be really glad. Do you have time for a quick discussion? I sent you the link on WeChat. Thanks so much!

@zhaochenyang20 zhaochenyang20 added RLHF Using SGLang for post training enhancement New feature or request collaboration feature labels Dec 24, 2024
@fzyzcjy
Author

fzyzcjy commented Dec 24, 2024

> we may not try to integrate SGLang into TRL ourselves since TRL is more or less out of date. But for OpenRLHF, yes. Definitely, we will do this. And I am working on that.

Interesting! I have searched how people are doing RLHF nowadays and found TRL and OpenRLHF (and maybe other frameworks?), and it seems TRL is more popular than OpenRLHF. Could you tell me a bit more about the "out of date" point? If it is out of date, then I will not spend much time looking at it.

> For OpenRLHF train_ppo, it's not distributed, and we have an engine API for update_weights_from_distributed. I am not sure if the current API in SGLang already can do this, or should we make some similar API? Like directly passing in the weights and name rather than using torch.distributed to broadcast weights?

Yes, that would be super great. I have tried the gloo backend, and it can broadcast weights to the same GPU (while nccl throws an error), but I am not sure whether that is suboptimal. For example, I guess it involves at least one memory copy? Instead, if we already know the weights are on the same GPU, maybe we can do zero-copy to save some time, and also avoid the need for broadcasting, which adds complexity and can introduce bugs.
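Passing (name, tensor) pairs directly, instead of broadcasting through torch.distributed, could look roughly like this when trainer and engine share a process and device. This is a sketch assuming PyTorch, with `update_weights_in_place` as a hypothetical helper (not an existing SGLang function) and two small `nn.Linear` modules standing in for the HF policy and the engine's copy. It runs on CPU here; the same calls work on a CUDA device.

```python
import torch
from torch import nn


def update_weights_in_place(
    target: nn.Module, named_tensors: list[tuple[str, torch.Tensor]]
) -> None:
    """Hypothetical helper: load (name, tensor) pairs into `target` without torch.distributed."""
    params = dict(target.named_parameters())
    with torch.no_grad():
        for name, tensor in named_tensors:
            # On the same device this is one local copy: no process group, no collective.
            params[name].copy_(tensor)


torch.manual_seed(0)
trainer_model = nn.Linear(4, 4)    # stands in for the HF policy being trained
inference_model = nn.Linear(4, 4)  # stands in for the inference engine's copy

with torch.no_grad():
    trainer_model.weight.add_(1.0)  # pretend an optimizer step changed the policy

update_weights_in_place(inference_model, list(trainer_model.named_parameters()))
print(torch.equal(inference_model.weight, trainer_model.weight))  # True
# The engine copy keeps its own storage, so later training steps do not leak into it.
print(inference_model.weight.data_ptr() != trainer_model.weight.data_ptr())  # True
```

This still performs one device-local copy. A true zero-copy variant would make the engine's parameters alias the trainer's tensors, which saves the copy but means the engine sees weights change as soon as the optimizer steps. A real engine may also store weights under different names or fused layouts, so a name-mapping step is likely needed.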

> Also, I am not quite sure why we need to discuss this. In my experience, every time the policy model gets updated, we should also update the inference engine's weights. It's not related to the PPO batch size but only related to the training epochs?

> Also, in your title, “Drop-in fast replacement of PreTrainedModel.generate”. Do you mean that changing the inference engine in these post-training frameworks, from huggingface/vllm to sglang? I am working on this and making some progress.

My proposal was a drop-in replacement. In other words, users only need to do something like model = SGLang.magically_wrap_the_model(model), and then they do not need to change any existing code that calls model.generate. It is we who wrap model.generate so that it calls SGLang's generate instead of HF's.

For example, users can directly use TRL's PPOTrainer and OpenRLHF's train_ppo (non-ray) trainer, with the only change being magically_wrap_the_model, without changing a single line of trainer code.

At the same time, we can surely also do the non-drop-in approach, i.e. directly modify the code of TRL/OpenRLHF/whatever framework to add it.

Happy to see this is WIP and I am also happy to contribute some PRs (for this one and #2542)!

@zhaochenyang20
Collaborator

> Interesting! I have searched how people are doing OpenRLHF nowadays, have found TRL and OpenRLHF (and maybe other frameworks?), and it seems TRL is more popular than OpenRLHF. May I know a bit about the "out of date" thing? If it is out of date then I will not spend much time looking at it.

Well, a lot of my friends use OpenRLHF since it's easier to hack than TRL. I think TRL and OpenRLHF are both good, and both are great projects for us to contribute to.

> Yes that would be super great. I have tried to use gloo backend and it can broadcast weight to the same GPU (while nccl throws an error), but I am not sure whether that's suboptimal. For example, I guess it at least has a memory copy? Instead, if we already know it is on the same GPU, maybe we can do zero-copy to save some time and also avoid the need for broadcasting, which is extra complexity and can introduce bugs.

Well, yesterday I used OpenRLHF with vllm on Ray across 8 × H100, but the broadcast failed. Maybe I should use gloo instead of nccl this time; I will try it out. Also, I am not sure how big a model we can fit on a single GPU. PPO on a 7B model takes 3 × H100 with Adam offloading and with ref/actor and critic/reward co-located. If we can do this on one GPU, that would be perfect.

> My proposal was a drop-in replacement. In other words, users only need to do something like, say, model = SGLang.magically_wrap_the_model(model), and later they will not need to change any existing code to call model.generate. It is us that changes (wraps) model.generate to not call hf's generate but sglang's generate.
>
> For example, users can directly use TRL's PPOTrainer and OpenRLHF's train_ppo (non-ray) trainer, with the only change being magically_wrap_the_model, without changing a single line of code about the trainer.
>
> At the same time, we surely can do the non-dropin things, i.e. directly modify the TRL/OpenRLHF/whatever framework's code to add it.

Sorry, I do not fully understand what drop-in means 😂 I always do non-drop-in integration, which we call open to use, like how we integrated SGLang with xgrammar. We can discuss this tomorrow. Thanks so much for the help, and merry Christmas.

@fzyzcjy
Author

fzyzcjy commented Dec 25, 2024

> A lot of my friends use OpenRLHF since it's easier to hack than trl. I think TRL and OpenRLHF are both good and perfect for us to contribute.

I see, thanks for the info.

> I am not sure how big the model we can use on a single GPU? PPO a 7B model takes 3 * H100 with adam offloading and co-locate ref/actor, critic/reward. If we can do this on one GPU, that would be perfect.

Theoretically speaking, if we use bf16 model weights, then

  • policy: 7B params × 2 bytes/param = 14 GB
  • critic: 7B params × 2 bytes/param = 14 GB
  • ref: 7B params × 2 bytes/param = 14 GB
  • reward: 7B params × 2 bytes/param = 14 GB

=> at least 56 GB for weights alone.

If we can make SGLang use almost zero memory by temporarily releasing its model weights and KV cache, then it looks like everything could fit on one 80 GB card, though I am not sure whether enough memory would be left for a large enough batch size in forward/backward.
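The arithmetic above can be written out as a quick estimate. This sketch covers weights only and uses decimal GB to match the 7B × 2 bytes = 14 GB figures. The Adam line is an added assumption, not from the thread: it uses the common mixed-precision rule of thumb of fp32 master weights plus two fp32 moments (12 bytes/param) for the two trainable models, which helps explain why Adam offloading matters for the 3 × H100 setup mentioned earlier.

```python
# Back-of-the-envelope GPU memory estimate for single-GPU PPO with 7B models.
GB = 10**9                # decimal gigabytes
BYTES_PER_PARAM_BF16 = 2


def weight_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM_BF16) -> float:
    """Memory needed to hold `num_params` values at `bytes_per_param`, in GB."""
    return num_params * bytes_per_param / GB


models = {"policy": 7e9, "critic": 7e9, "ref": 7e9, "reward": 7e9}
per_model = {name: weight_gb(n) for name, n in models.items()}
total = sum(per_model.values())

card_gb = 80              # one 80 GB H100
headroom = card_gb - total

# Assumption: mixed-precision Adam keeps fp32 master weights + two fp32 moments
# (12 bytes/param) for each trainable model, here policy and critic.
adam_gb = weight_gb(models["policy"] + models["critic"], bytes_per_param=12)

print(per_model)
print(f"bf16 weights: {total:.0f} GB, headroom on {card_gb} GB card: {headroom:.0f} GB")
print(f"Adam states if kept on GPU: {adam_gb:.0f} GB")
```

The remaining ~24 GB must hold activations, gradients, and any optimizer state that is not offloaded, so the batch size question above is the real constraint.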

> Sorry. I do not fully understand what is drop-in 😂 I always do non-drop-in and we call this open to use. Like how we integrate SGLang with xgrammar.

A super naive version would be:

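from transformers import PreTrainedModel

class SGLangModelWrapper:
  def __init__(self, model: PreTrainedModel):
    self.model = model
    # hypothetical helper: build an SGLang engine holding the same weights
    self.sglang_engine = create_sglang_engine_from_model(model)

  # be a proxy class (should write more code than this, e.g. forward __call__)
  def __getattr__(self, attr):
    # only reached for attributes not found on the wrapper itself
    return getattr(self.model, attr)

  def generate(self, *args, **kwargs):
    # real code must translate HF generate arguments into SGLang's
    return self.sglang_engine.generate(*args, **kwargs)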

usage:

model = SGLangModelWrapper(model)
trl.PPOTrainer(model=model).train()

Then, even though PPOTrainer calls PreTrainedModel.generate, our SGLangModelWrapper makes it call SGLang's generate instead, which is faster.
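A self-contained version of the proxy idea can be run without transformers or SGLang. In this sketch `FakeHFModel` and `FakeEngine` are hypothetical stand-ins (not real library classes), and `GenerateOverrideProxy` is an illustrative name. It also shows one piece of the "more code" the naive wrapper needs: Python looks up special methods such as `__call__` on the type, so `__getattr__` alone would not forward `model(...)` forward calls.

```python
from typing import Any


class FakeHFModel:
    """Hypothetical stand-in for a transformers PreTrainedModel."""

    def __init__(self) -> None:
        self.config = {"name": "tiny-policy"}

    def __call__(self, x: int) -> int:
        return x * 2  # pretend forward pass

    def generate(self, prompt: str) -> str:
        return f"hf:{prompt}"


class FakeEngine:
    """Hypothetical stand-in for an SGLang engine (not the real SGLang API)."""

    def generate(self, prompt: str) -> str:
        return f"sglang:{prompt}"


class GenerateOverrideProxy:
    def __init__(self, model: Any, engine: Any) -> None:
        self.model = model
        self.engine = engine

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails; guard avoids recursion before __init__ runs.
        if name in ("model", "engine"):
            raise AttributeError(name)
        return getattr(self.model, name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Special methods bypass __getattr__, so forward the call explicitly.
        return self.model(*args, **kwargs)

    def generate(self, *args: Any, **kwargs: Any) -> Any:
        return self.engine.generate(*args, **kwargs)


wrapped = GenerateOverrideProxy(FakeHFModel(), FakeEngine())
print(wrapped.generate("hi"))   # sglang:hi
print(wrapped(21))              # 42
print(wrapped.config["name"])   # tiny-policy
```

One caveat with a proxy: `isinstance(wrapped, FakeHFModel)` is False, so a trainer that type-checks its model would reject it. Replacing `generate` on the model instance itself is an alternative that keeps the original type.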

> Thanks so much for help and merry Christmas.

You are welcome, and also merry Christmas :)
