
Add Bamba Model #10909

Open
wants to merge 65 commits into main
Conversation


@fabianlim fabianlim commented Dec 5, 2024

This is the companion PR to a Hugging Face PR for adding Bamba, a hybrid Mamba2 architecture with SwiGLU. The checkpoints are jointly trained by IBM, Princeton, and UIUC.

In this PR we have:

  • Created the Bamba model inference architecture. We would like to acknowledge the Jamba team: we referenced their implementation and modified it to support full attention layers with RoPE and Mamba v2.
  • Ensured that we have TP support.
  • Ensured we support chunked prefill. Initially we had a partial solution that worked only when the continuous-batching boundaries lined up with the chunk boundaries; this is now completely fixed.
  • Ensured that we conform to the recent PR for adding pipeline support for SSM models.
  • Adapted the Mamba v2 scan kernels into vllm/model_executor/layers/mamba/ops. Only the forward kernels are extracted, with some modifications and fixes.
  • Created tests/models/decoder_only/language/test_bamba.py with an initial ibm-fms/Bamba-9.8b-1.8T-hf checkpoint. This is practically identical to test_mamba.py, except that chunked prefill tests are disabled since chunked prefill is not yet supported there (a short usage sketch follows this list).
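For context, a minimal offline-inference sketch for trying the model (this follows vLLM's standard LLM entry point and uses the checkpoint from the test file above; the tensor_parallel_size value is only an illustrative choice to exercise the TP path, not something prescribed by this PR):

from vllm import LLM, SamplingParams

# Load the checkpoint referenced in test_bamba.py; tensor_parallel_size=2 is
# an example value to exercise the tensor-parallel support described above.
llm = LLM(model="ibm-fms/Bamba-9.8b-1.8T-hf", tensor_parallel_size=2)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Hybrid Mamba-2 models combine"], params)
print(outputs[0].outputs[0].text)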

Currently only the FlashAttention backend is supported, since we check fields like context_lens_tensor. We have not yet investigated other backends.

We would also like to acknowledge the draft Codestral Mamba PR from @tlrmchlsmth, whose mixer we also referenced:

  • We made a few simplifications for Bamba (the mixer is simplified relative to Mamba v2).
  • CUDA graph capture seems to be working, but we understand that CUDA graphs are disabled for long sequence lengths. SSM models are strongest in exactly that regime, so can we handle it better?

We hope to discuss the following with the maintainers:

  1. Do we have to remove all the backward kernels? Yes, we should.
  2. For the full attention layers, we grow the sin/cos cache to cover the sequence length if it exceeds max_sequence_len. This differs from other current models (e.g., Llama). How can we better support long sequence lengths? We should keep this consistent with other models, so we propose allowing the sin/cos cache extension only when VLLM_ALLOW_LONG_MAX_MODEL_LEN is specified (see the sketch after this list).
  3. We have some ideas to support chunked prefill, but would appreciate some discussion with the maintainers on how to proceed. We are working on changing the kernels to support chunked prefill.
  4. Since the mixer2 is simplified from Mamba, should we rename it? We can keep it as is, but we should document the differences from mamba_ssm.
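To make point 2 above concrete, here is a hypothetical sketch of gating the cache extension behind the environment variable. The attributes max_position_embeddings, rotary_dim, base, and cos_sin_cache on the rotary embedding object are assumptions for illustration, not the exact vLLM internals:

import os

import torch


def maybe_extend_cos_sin_cache(rotary_emb, seq_len: int) -> None:
    # Hypothetical helper: only grow the rotary cos/sin cache beyond the
    # configured maximum when the user has explicitly opted in.
    if seq_len <= rotary_emb.max_position_embeddings:
        return
    if not os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN"):
        raise ValueError(
            f"Sequence length {seq_len} exceeds "
            f"{rotary_emb.max_position_embeddings}; set "
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 to allow extending the cache.")
    # Recompute inverse frequencies and the cos/sin table up to seq_len.
    dim = rotary_emb.rotary_dim
    inv_freq = 1.0 / (rotary_emb.base**(
        torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    rotary_emb.cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)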

cc: @ani300, @raghukiran1224, @cyang49, @njhill

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim fabianlim marked this pull request as draft December 5, 2024 01:35

github-actions bot commented Dec 5, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@tlrmchlsmth
Collaborator

Hi @fabianlim, thanks for the PR! It's really great to see progress being made on state-space models, especially for me as I unfortunately haven't been able to prioritize support for Mamba2

I'm happy to shepherd this PR and discuss any questions you have, especially to support chunked prefill. If you haven't already, can you join the developer slack for quicker discussion? (https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack)

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim
Author

fabianlim commented Dec 12, 2024

@tlrmchlsmth I cleaned up the PR quite a bit; perhaps it is a good time to get some early eyes on it. The chunked prefill implementation is incomplete at the moment, as we discussed offline.

Collaborator
@tlrmchlsmth tlrmchlsmth left a comment

first pass, just a few comments. At a high level it looks good.

Will you add a test for tensor parallelism?

Comment on lines 9 to 10
# will be ch
MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"]
Collaborator

The comment trails off, but will there be a small test model available?

Author

@raghukiran1224 any plans for a small test model? Since we do output comparisons, I think it is not that useful to just have a randomly initialised small model.


@fabianlim @tlrmchlsmth would it be ok to test with a random model or would you rather have a tiny model (say 200M or so) to test with?

Collaborator

A tiny model with nonrandom weights would be much better!

Collaborator

btw is there any update on this?

tests/models/decoder_only/language/test_bamba.py (outdated review thread, resolved)
vllm/model_executor/layers/mamba/ops/ssd_bmm.py (outdated review thread, resolved)

mergify bot commented Dec 13, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fabianlim.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 13, 2024
@fabianlim
Author

@tlrmchlsmth I have addressed most of your comments now; I am not rebasing yet, waiting for you to look first. But I realized test_jamba.py has changed, so I will need to do the rename and test again.

Collaborator
@tlrmchlsmth tlrmchlsmth left a comment

@fabianlim At a high level, the changes look good, and the PR looks good overall. I'll do a more thorough review once it's unmarked as draft.

Could you add unit tests for the added kernels in layers/mamba/ops?

tlrmchlsmth and others added 5 commits January 17, 2025 18:01
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>

mergify bot commented Jan 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fabianlim.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 20, 2025
lora_config = vllm_config.lora_config

self.config = config
self.padding_idx = config.pad_token_id


is this one used anywhere?

Author

Oh, sorry, good catch. I will remove it.

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
@fabianlim
Author

Note to self, some of the testing API has changed due to this PR #10353

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@mergify mergify bot removed the needs-rebase label Feb 1, 2025
Collaborator
@tlrmchlsmth tlrmchlsmth left a comment

Looking pretty good, let's get this landed now that 4.48.2 is out!

Collaborator

Let's land mamba 2 in #9292

Collaborator

Could you merge in latest main? We've already landed this change

Author

Yeah, actually yesterday I reverted this file and took it from latest main, but somehow the diff still shows up in GitHub. The version on the left shown by GitHub is actually old.

Author

OK, if I merge in latest main it seems fine.

Collaborator

I don't think this file should be changed now that #12599 has been merged. Could you revert this file?

Also wondering why this has more changes than in #12599 - did you run into any additional issues that required these additional changes?

Author
@fabianlim fabianlim Feb 1, 2025

Same as above: the version on the left is old; the right is from latest main.

Collaborator

Let's land these changes as part of #9292

Comment on lines +33 to +123
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):

    def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.full_hidden_size = full_hidden_size
        self.group_size = full_hidden_size // full_n_groups
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // self.group_size

        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
        set_weight_attrs(self.weight,
                         {"weight_loader": sharded_weight_loader(0)})
        assert self.full_hidden_size % self.tp_size == 0, \
            "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        # Three tensor-parallel cases:
        #   1. n_groups is 1
        #      In this case we parallelize along the reduction dim.
        #      Each rank computes a local sum of squares followed by AllReduce
        #   2. tp_size divides n_groups
        #      Each rank only reduces within its local group(s).
        #      No collective ops necessary.
        #   3. The general case can be pretty complicated so we AllGather
        #      the input and then redundantly compute the RMSNorm.
        input_dtype = x.dtype
        x = x * nn.functional.silu(gate.to(torch.float32))

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
                global_sums = tensor_model_parallel_all_reduce(local_sums)
                # Calculate the variance
                count = self.tp_size * x.shape[-1]
                variance = (global_sums / count)
            else:
                variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
        else:
            redundant_tp: bool = self.n_groups % self.tp_size != 0
            if redundant_tp:
                # To handle the general case, redundantly apply the variance
                x = tensor_model_parallel_all_gather(x, -1)

            *prefix_dims, hidden_dim = x.shape
            group_count = hidden_dim // self.group_size
            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
            variance = x_grouped.pow(2).mean(-1, keepdim=True)
            x_grouped = x_grouped * torch.rsqrt(variance +
                                                self.variance_epsilon)
            x = x_grouped.view(*prefix_dims, hidden_dim)

            if redundant_tp:
                start = self.per_rank_hidden_size * self.tp_rank
                end = start + self.per_rank_hidden_size
                x = x[..., start:end]

        return self.weight * x.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

        if self.tp_size > 1 or self.n_groups != 1:
            return self.forward_native(x, gate)

        from vllm import _custom_ops as ops

        # cast x and gate to float32 before silu
        out = torch.empty_like(x)
        y = x * nn.functional.silu(gate.to(torch.float32))
        ops.rms_norm(
            out,
            y.to(x.dtype),
            self.weight.data,
            self.variance_epsilon,
        )
        return out
Collaborator

We should get a unit test in place for this, especially the various tensor parallel cases. @fabianlim do you have bandwidth to do that? Otherwise I can do it either in #9292 or in a separate PR. I do feel pretty good about correctness here, having manually tested various cases thoroughly enough.

Author

Just unit testing only the Mixer2RMSNormGated? If so, how would you set up the test? conftest only has runners for the whole model.
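For what it's worth, the math can be sanity-checked without vLLM's distributed runners by simulating the tensor-parallel reduction on a single process. A minimal sketch (the helper names below are hypothetical; they only mirror the n_groups == 1 formulas in forward_native above):

import torch
import torch.nn.functional as F


def gated_rms_norm_ref(x, gate, weight, eps=1e-6):
    # Single-rank reference of the n_groups == 1 path.
    dtype = x.dtype
    x = x * F.silu(gate.to(torch.float32))
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(dtype)


def gated_rms_norm_tp_sim(x, gate, weight, tp_size, eps=1e-6):
    # Shard the hidden dim, compute per-shard sums of squares, and emulate
    # the AllReduce with a plain sum before normalizing each shard.
    dtype = x.dtype
    x_shards = x.chunk(tp_size, dim=-1)
    gate_shards = gate.chunk(tp_size, dim=-1)
    w_shards = weight.chunk(tp_size, dim=-1)
    gated = [xs * F.silu(gs.to(torch.float32))
             for xs, gs in zip(x_shards, gate_shards)]
    global_sum = sum(g.pow(2).sum(dim=-1, keepdim=True) for g in gated)
    variance = global_sum / (tp_size * gated[0].shape[-1])
    outs = [w * (g * torch.rsqrt(variance + eps)).to(dtype)
            for w, g in zip(w_shards, gated)]
    return torch.cat(outs, dim=-1)


torch.manual_seed(0)
x, gate, weight = torch.randn(2, 7, 64), torch.randn(2, 7, 64), torch.randn(64)
torch.testing.assert_close(gated_rms_norm_ref(x, gate, weight),
                           gated_rms_norm_tp_sim(x, gate, weight, tp_size=4))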

@@ -69,6 +70,7 @@
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
Collaborator

Let's land this in #9292

@tlrmchlsmth
Collaborator

To debug the pre-commit issue locally you may need to run:

pre-commit run mypy-3.9 --hook-stage manual -a

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim
Author

@tlrmchlsmth thank you so much for your comments. I have fixed the pre-commit issues and reverted changes in a bunch of files; I kept the changes that are still needed to test the PR. Also, merging in upstream/main helped GitHub display the changes correctly. Regarding the new unit test, I left a question for you.

Labels: ci/build, ready