-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Bamba Model #10909
Open
fabianlim
wants to merge
65
commits into
vllm-project:main
Choose a base branch
from
fabianlim:bamba-pr
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3,563
−95
Open
Add Bamba Model #10909
Changes from 11 commits
Commits
Show all changes
65 commits
Select commit
Hold shift + click to select a range
62181d5
initial pr without tp fix
fabianlim 51bc78c
fix casting in rms norm gated
fabianlim 81b93b4
TP fix
fabianlim 0f93e4a
fix mamba scan invalid address
fabianlim 742ae79
some fixes and remove unused kernels
fabianlim b2dc5ca
fmt + lint
fabianlim 9ad9e20
more comments
fabianlim 25bf381
initial fix for chunked prefill (incomplete)
fabianlim 43ce07c
improve comments
fabianlim 80f14b5
do not attach seq_idx to attn_metadata
fabianlim 6b8ac49
activate initial states for chunked prefill
fabianlim d788db6
reuse softplus and remove triton2 remark
fabianlim 400db27
add comment on weight loader and format
fabianlim bda8ea7
rename test_jamba to test_hybrid and got rid of test_bamba
fabianlim 66078d6
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim a74de9f
update bamba to ishybrid and support pp
fabianlim b44caa7
lint
fabianlim 8cf3644
add unit test for mamba ssd
fabianlim e375b40
fix lint
fabianlim dcbae7b
full chunked-prefill fix (sans unit tests)
fabianlim 2597105
format and add cont batch unit tests (will need more cases)
fabianlim db5eea5
fix kernel tests and add more chunked prefill cases
fabianlim dfbcb16
bound adjustment
fabianlim 7913009
bound adjustment
fabianlim 9c5d045
lint errors
fabianlim 6bc9dac
Add permalink correction from @tlrmchlsmth
fabianlim 6d02e85
improved comment for segsum, add more sizes for test_mamba_chunk_scan…
fabianlim e5882f2
rename and comment functions, add more sizes for test_mamba_chunk_sca…
fabianlim 6d6fa86
addressed comments on mamba_mixer2.py
fabianlim 773dd80
replace with get_rope
fabianlim 63f5340
rope scaling
fabianlim 89e36d8
fixes
fabianlim 7a4ae96
zero out ssm states
fabianlim a9e149c
fix tests (sans updating dev checkpoint)
fabianlim 5c9f48d
not replacing dev model for now
fabianlim 55647b1
update requirements
fabianlim 2342bc0
remove extraneous comment
fabianlim 011c141
update test
fabianlim 503bc42
fix lint
fabianlim 312cf1d
fix lint
fabianlim c1db743
fix requirements-test
fabianlim c956a30
Mamba2 changes from #10909
tlrmchlsmth 17923ad
Get Mamba2 working!
tlrmchlsmth 4183d45
Add integration test -- something is wrong!!
tlrmchlsmth 5377644
format
tlrmchlsmth 39f55d1
fixes
tlrmchlsmth dd31f19
update test registry, fixes
fabianlim e2e5aac
Fix for conv state shape and update placeholder_attn
tlrmchlsmth bc1b8af
back out placeholder_attn changes
tlrmchlsmth 9db0dd5
make seq_idx to chunk indices more efficient
fabianlim cd89283
WIP debugging, restore local mamba and placeholder_attn changes
tlrmchlsmth 9a838a3
Integration tests are now green
tlrmchlsmth be8318e
remove bamba-specific files
tlrmchlsmth f34d434
Merge branch 'main' into tms/mamba2
tlrmchlsmth a65e2cb
Handle grouping in Mixer2RMSNormGated
tlrmchlsmth 0d4bb0f
debug cruft
tlrmchlsmth 74f6088
Remove codestral integration test
tlrmchlsmth 95583b8
Merge branch 'tms/mamba2' into bamba-pr
fabianlim b72389c
update mamba_cache
fabianlim 10d75eb
remove changes to requirements
fabianlim 5aea1e6
revert changes
fabianlim 2ee8d07
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim 043e006
fix lint
fabianlim 7e4ce4f
fix lint
fabianlim 8219480
more reverts
fabianlim File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
import pytest | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from vllm.config import VllmConfig | ||
from vllm.sampling_params import SamplingParams | ||
|
||
from ...utils import check_outputs_equal | ||
|
||
# will be ch | ||
MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"] | ||
|
||
|
||
# Use lower-level interfaces to create this greedy generator, as mamba will | ||
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. | ||
fabianlim marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def generate_greedy(model_name, example_prompts, max_tokens): | ||
# Create a text generation pipeline | ||
tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
model = AutoModelForCausalLM.from_pretrained(model_name) | ||
|
||
# Generate texts from the prompts | ||
outputs = [] | ||
for prompt in example_prompts: | ||
# Tokenize the input prompt with truncation | ||
inputs = tokenizer(prompt, return_tensors="pt", truncation=True) | ||
input_ids = inputs["input_ids"] | ||
|
||
# Generate text using the model's generate method directly | ||
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) | ||
generated_text = tokenizer.decode(generated_ids[0], | ||
skip_special_tokens=True) | ||
|
||
outputs.append((generated_ids[0].tolist(), generated_text)) | ||
|
||
return outputs | ||
|
||
|
||
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. | ||
This actually is really identical to test_mamba, so maybe we can reuse | ||
Run `pytest tests/models/decoder_only/language/test_bamba.py`. | ||
""" | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [96]) | ||
def test_models( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
hf_outputs = generate_greedy(model, example_prompts, max_tokens) | ||
|
||
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: | ||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||
# This test is for verifying whether the model's extra_repr | ||
# can be printed correctly. | ||
print(vllm_model.model.llm_engine.model_executor.driver_worker. | ||
model_runner.model) | ||
|
||
for i in range(len(example_prompts)): | ||
hf_output_ids, hf_output_str = hf_outputs[i] | ||
vllm_output_ids, vllm_output_str = vllm_outputs[i] | ||
assert hf_output_str == vllm_output_str, ( | ||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") | ||
assert hf_output_ids == vllm_output_ids, ( | ||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [96]) | ||
def test_batching( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
# To pass the small model tests, we need full precision. | ||
for_loop_outputs = [] | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
for prompt in example_prompts: | ||
for_loop_outputs.append( | ||
vllm_model.generate_greedy([prompt], max_tokens)[0]) | ||
|
||
batched_outputs = vllm_model.generate_greedy(example_prompts, | ||
max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=for_loop_outputs, | ||
outputs_1_lst=batched_outputs, | ||
name_0="for_loop_vllm", | ||
name_1="batched_vllm", | ||
) | ||
|
||
|
||
@pytest.mark.skip("bamba does not support chunked prefill yet") | ||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [10]) | ||
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, | ||
model: str, dtype: str, | ||
max_tokens: int) -> None: | ||
# Tests chunked prefill in conjunction with n>1. In this case, prefill is | ||
# populated with decoding tokens and we test that it doesn't fail. | ||
# This test might fail if cache is not allocated correctly for n > 1 | ||
# decoding steps inside a chunked prefill forward pass (where we have both | ||
# prefill and decode together ) | ||
sampling_params = SamplingParams(n=3, | ||
temperature=1, | ||
seed=0, | ||
max_tokens=max_tokens) | ||
with vllm_runner( | ||
model, | ||
dtype=dtype, | ||
enable_chunked_prefill=True, | ||
max_num_batched_tokens=30, | ||
max_num_seqs=10 # forces prefill chunks with decoding | ||
) as vllm_model: | ||
vllm_model.generate(example_prompts, sampling_params) | ||
|
||
|
||
@pytest.mark.skip("bamba does not support chunked prefill yet") | ||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [32]) | ||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) | ||
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, | ||
max_tokens: int, | ||
chunked_prefill_token_size: int) -> None: | ||
""" | ||
Checks exact match decode between huggingface model and vllm runner with | ||
chunked prefill. | ||
""" | ||
max_num_seqs = chunked_prefill_token_size | ||
max_num_batched_tokens = chunked_prefill_token_size | ||
|
||
non_chunked = generate_greedy(model, example_prompts, max_tokens) | ||
|
||
with vllm_runner(model, | ||
dtype=dtype, | ||
enable_chunked_prefill=True, | ||
max_num_batched_tokens=max_num_batched_tokens, | ||
max_num_seqs=max_num_seqs) as vllm_model: | ||
chunked = vllm_model.generate_greedy(example_prompts, | ||
max_tokens=max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=chunked, | ||
outputs_1_lst=non_chunked, | ||
name_0="chunked", | ||
name_1="non_chunked", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [15]) | ||
def test_parallel_sampling( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
|
||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
for_loop_outputs = [] | ||
for _ in range(10): | ||
for_loop_outputs.append( | ||
# using example_prompts index 1 instead of 0 since with 0 the | ||
# logprobs get really close and the test doesn't pass | ||
vllm_model.generate_greedy([example_prompts[1]], max_tokens) | ||
[0]) | ||
sampling_params = SamplingParams(n=10, | ||
temperature=0.001, | ||
seed=0, | ||
max_tokens=max_tokens) | ||
n_lt_1_outputs = vllm_model.generate([example_prompts[1]], | ||
sampling_params) | ||
token_ids, texts = n_lt_1_outputs[0] | ||
n_lt_1_outputs = [(token_id, text) | ||
for token_id, text in zip(token_ids, texts)] | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=n_lt_1_outputs, | ||
outputs_1_lst=for_loop_outputs, | ||
name_0="vllm_n_lt_1_outputs", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
@pytest.mark.parametrize("max_tokens", [20]) | ||
def test_mamba_cache_cg_padding( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
# This test is for verifying that mamba cache is padded to CG captured | ||
# batch size. If it's not, a torch RuntimeError will be raised because | ||
# tensor dimensions aren't compatible | ||
while len(example_prompts) == VllmConfig.get_graph_batch_size( | ||
len(example_prompts)): | ||
example_prompts.append(example_prompts[0]) | ||
|
||
try: | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
vllm_model.generate_greedy(example_prompts, max_tokens) | ||
except RuntimeError: | ||
pytest.fail( | ||
"Couldn't run batch size which is not equal to a Cuda Graph " | ||
"captured batch size. " | ||
"Could be related to mamba cache not padded correctly") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [20]) | ||
def test_models_preemption_recompute( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
# Tests that outputs are identical with and w/o preemtions (recompute) | ||
assert dtype == "float" | ||
|
||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
vllm_model.model.llm_engine.scheduler[ | ||
0].ENABLE_ARTIFICIAL_PREEMPT = True | ||
preempt_vllm_outputs = vllm_model.generate_greedy( | ||
example_prompts, max_tokens) | ||
|
||
vllm_model.model.llm_engine.scheduler[ | ||
0].ENABLE_ARTIFICIAL_PREEMPT = False | ||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=preempt_vllm_outputs, | ||
outputs_1_lst=vllm_outputs, | ||
name_0="vllm_preepmtions", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
example_prompts, | ||
) -> None: | ||
# This test is for verifying that the Mamba inner state management doesn't | ||
# collapse in case where the number of incoming requests and | ||
# finished_requests_ids is larger than the maximum Mamba block capacity. | ||
# This could generally happen due to the fact that Mamba does support | ||
# statelessness mechanism where it can cleanup new incoming requests in | ||
# a single step. | ||
try: | ||
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: | ||
vllm_model.generate_greedy([example_prompts[0]] * 100, 10) | ||
except ValueError: | ||
pytest.fail("Mamba inner state wasn't cleaned up properly between" | ||
"steps finished requests registered unnecessarily ") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_state_cleanup( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
example_prompts, | ||
) -> None: | ||
# This test is for verifying that the Mamba state is cleaned up between | ||
# steps, If its not cleaned, an error would be expected. | ||
try: | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
for _ in range(10): | ||
vllm_model.generate_greedy([example_prompts[0]] * 100, 1) | ||
except ValueError: | ||
pytest.fail("Mamba inner state wasn't cleaned up between states, " | ||
"could be related to finished_requests_ids") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_multistep( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
example_prompts, | ||
) -> None: | ||
with vllm_runner(model, num_scheduler_steps=8, | ||
max_num_seqs=2) as vllm_model: | ||
vllm_model.generate_greedy([example_prompts[0]] * 10, 1) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [64]) | ||
def test_multistep_correctness(vllm_runner, model: str, dtype: str, | ||
max_tokens: int, example_prompts) -> None: | ||
with vllm_runner(model, num_scheduler_steps=8, | ||
max_num_seqs=2) as vllm_model: | ||
vllm_outputs_multistep = vllm_model.generate_greedy( | ||
example_prompts, max_tokens) | ||
|
||
with vllm_runner(model, num_scheduler_steps=1, | ||
max_num_seqs=2) as vllm_model: | ||
vllm_outputs_single_step = vllm_model.generate_greedy( | ||
example_prompts, max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=vllm_outputs_multistep, | ||
outputs_1_lst=vllm_outputs_single_step, | ||
name_0="vllm_outputs_multistep", | ||
name_1="vllm_outputs_single_step", | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment trails off, but will there be a small test model available?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@raghukiran1224 any plans for a small test model? I think since we do outputs comparison it is not that good to just have a randomly initialised small model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fabianlim @tlrmchlsmth would it be ok to test with a random model or would you rather have a tiny model (say 200M or so) to test with?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A tiny model with nonrandom weights would be much better!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw is there any update on this?