Add Bamba Model #10909

Open · wants to merge 65 commits into base: main

Commits (65)
62181d5
initial pr without tp fix
fabianlim Dec 5, 2024
51bc78c
fix casting in rms norm gated
fabianlim Dec 5, 2024
81b93b4
TP fix
fabianlim Dec 5, 2024
0f93e4a
fix mamba scan invalid address
fabianlim Dec 8, 2024
742ae79
some fixes and remove unused kernels
fabianlim Dec 12, 2024
b2dc5ca
fmt + lint
fabianlim Dec 12, 2024
9ad9e20
more comments
fabianlim Dec 12, 2024
25bf381
initial fix for chunked prefill (incomplete)
fabianlim Dec 12, 2024
43ce07c
improve comments
fabianlim Dec 12, 2024
80f14b5
do not attach seq_idx to attn_metadata
fabianlim Dec 12, 2024
6b8ac49
activate initial states for chunked prefill
fabianlim Dec 12, 2024
d788db6
reuse softplus and remove triton2 remark
fabianlim Dec 13, 2024
400db27
add comment on weight loader and format
fabianlim Dec 13, 2024
bda8ea7
rename test_jamba to test_hybrid and got rid of test_bamba
fabianlim Dec 13, 2024
66078d6
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim Dec 16, 2024
a74de9f
update bamba to ishybrid and support pp
fabianlim Dec 16, 2024
b44caa7
lint
fabianlim Dec 16, 2024
8cf3644
add unit test for mamba ssd
fabianlim Dec 16, 2024
e375b40
fix lint
fabianlim Dec 16, 2024
dcbae7b
full chunked-prefill fix (sans unit tests)
fabianlim Dec 21, 2024
2597105
format and add cont batch unit tests (will need more cases)
fabianlim Dec 23, 2024
db5eea5
fix kernel tests and add more chunked prefill cases
fabianlim Dec 23, 2024
dfbcb16
bound adjustment
fabianlim Dec 23, 2024
7913009
bound adjustment
fabianlim Dec 26, 2024
9c5d045
lint errors
fabianlim Dec 26, 2024
6bc9dac
Add permalink correction from @tlrmchlsmth
fabianlim Jan 3, 2025
6d02e85
improved comment for segsum, add more sizes for test_mamba_chunk_scan…
fabianlim Jan 3, 2025
e5882f2
rename and comment functions, add more sizes for test_mamba_chunk_sca…
fabianlim Jan 3, 2025
6d6fa86
addressed comments on mamba_mixer2.py
fabianlim Jan 3, 2025
773dd80
replace with get_rope
fabianlim Jan 3, 2025
63f5340
rope scaling
fabianlim Jan 4, 2025
89e36d8
fixes
fabianlim Jan 6, 2025
7a4ae96
zero out ssm states
fabianlim Jan 7, 2025
a9e149c
fix tests (sans updating dev checkpoint)
fabianlim Jan 7, 2025
5c9f48d
not replacing dev model for now
fabianlim Jan 11, 2025
55647b1
update requirements
fabianlim Jan 13, 2025
2342bc0
remove extraneous comment
fabianlim Jan 14, 2025
011c141
update test
fabianlim Jan 14, 2025
503bc42
fix lint
fabianlim Jan 15, 2025
312cf1d
fix lint
fabianlim Jan 15, 2025
c1db743
fix requirements-test
fabianlim Jan 15, 2025
c956a30
Mamba2 changes from #10909
tlrmchlsmth Jan 16, 2025
17923ad
Get Mamba2 working!
tlrmchlsmth Jan 16, 2025
4183d45
Add integration test -- something is wrong!!
tlrmchlsmth Jan 17, 2025
5377644
format
tlrmchlsmth Jan 17, 2025
39f55d1
fixes
tlrmchlsmth Jan 17, 2025
dd31f19
update test registry, fixes
fabianlim Jan 16, 2025
e2e5aac
Fix for conv state shape and update placeholder_attn
tlrmchlsmth Jan 19, 2025
bc1b8af
back out placeholder_attn changes
tlrmchlsmth Jan 19, 2025
9db0dd5
make seq_idx to chunk indices more efficient
fabianlim Jan 20, 2025
cd89283
WIP debugging, restore local mamba and placeholder_attn changes
tlrmchlsmth Jan 20, 2025
9a838a3
Integration tests are now green
tlrmchlsmth Jan 20, 2025
be8318e
remove bamba-specific files
tlrmchlsmth Jan 20, 2025
f34d434
Merge branch 'main' into tms/mamba2
tlrmchlsmth Jan 27, 2025
a65e2cb
Handle grouping in Mixer2RMSNormGated
tlrmchlsmth Jan 30, 2025
0d4bb0f
debug cruft
tlrmchlsmth Jan 30, 2025
74f6088
Remove codestral integration test
tlrmchlsmth Jan 30, 2025
95583b8
Merge branch 'tms/mamba2' into bamba-pr
fabianlim Feb 1, 2025
b72389c
update mamba_cache
fabianlim Feb 1, 2025
10d75eb
remove changes to requirements
fabianlim Feb 1, 2025
5aea1e6
revert changes
fabianlim Feb 1, 2025
2ee8d07
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim Feb 1, 2025
043e006
fix lint
fabianlim Feb 1, 2025
7e4ce4f
fix lint
fabianlim Feb 1, 2025
8219480
more reverts
fabianlim Feb 1, 2025
302 changes: 302 additions & 0 deletions tests/kernels/test_mamba_ssm_ssd.py
@@ -0,0 +1,302 @@
from typing import Dict, Tuple

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined)
from vllm.platforms import current_platform

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py


# this is the segsum implementation taken from above
def segsum(x):
    """Calculates segment sum."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
                      diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
                      diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
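# e.g., for x = [x0, x1, x2], segsum(x)[i, j] = x_{j+1} + ... + x_i when
# i >= j (so segsum(x)[2, 0] = x1 + x2 and the diagonal is 0) and -inf when
# i < j; exp(segsum(A)) therefore gives the lower-triangular decay matrix L
# used in ssd_minimal_discrete below.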


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
                  for x in (X, A, B, C))

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
    #    chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms
    # (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
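# Equivalently, per batch and head this computes
#   Y[t] = sum_{s <= t} (C[t] . B[s]) * exp(A[s+1] + ... + A[t]) * X[s]
# (plus the decayed contribution of initial_states), i.e. Y = (L ∘ (C B^T)) X
# with L = exp(segsum(A)); the chunking above evaluates the same identity
# blockwise so that only per-chunk states need to be carried across chunk
# boundaries.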


def generate_random_inputs(batch_size,
                           seqlen,
                           n_heads,
                           d_head,
                           itype,
                           device='cuda'):

    current_platform.seed_everything(0)
    A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
    dt = F.softplus(
        torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
        4)
    X = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    B = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    C = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)

    return A, dt, X, B, C
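# note: mamba_chunk_scan_combined takes dt as a separate argument, while
# ssd_minimal_discrete does not, so the reference calls below fold the step
# size in by passing X * dt and A * dt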


def generate_continous_batched_examples(example_lens_by_batch,
                                        num_examples,
                                        full_length,
                                        last_taken,
                                        exhausted,
                                        n_heads,
                                        d_head,
                                        itype,
                                        device='cuda'):

    # this function generates random examples of a certain length,
    # cuts them according to "example_lens_by_batch", and feeds
    # them in continuous batches to the kernels

    # generate the full-length example
    A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
                                            d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                                  A * dt,
                                                  B,
                                                  C,
                                                  block_len=full_length // 4)

    # internal function that outputs a cont batch of examples
    # given a tuple of lengths for each example in the batch
    # e.g., example_lens=(8, 4) means take 8 samples from the first example,
    # 4 samples from the second example, etc.
    def get_continuous_batch(example_lens: Tuple[int, ...]):

        indices = []
        for i, x in enumerate(example_lens):
            c = last_taken.get(i, 0)
            indices.append((c, c + x))
            last_taken[i] = (c + x) % full_length
            exhausted[i] = last_taken[i] == 0

        return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
                              ]).unsqueeze(0) for x in (dt, X, B, C))

    # internal function that maps "n" to the appropriate right boundary
    # value when forming continuous batches from examples of length given
    # by "full_length".
    # - e.g., when n > full_length, returns n % full_length
    #   when n == full_length, returns full_length
    def end_boundary(n: int):
        return n - ((n - 1) // full_length) * full_length
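    # e.g., with full_length == 64: end_boundary(6) == 6,
    # end_boundary(64) == 64, and end_boundary(70) == 6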

    IND_E = None
    for spec in example_lens_by_batch:

        # get the (maybe partial) example seen in this cont batch
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # get the metadata
        cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
        sed_idx = torch.zeros(cu_seqlens[-1],
                              dtype=torch.int32,
                              device=cu_seqlens.device)
        for i, (srt, end) in enumerate(zip(
                cu_seqlens,
                cu_seqlens[1:],
        )):
            sed_idx[srt:end] = i

        # for cont batch
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
               cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):

# this tests the kernels on a single example (no batching)

# set seed
batch_size = 1 # batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen
# - this is only required for generating the reference seqs,
# it is not an operational limitation.
seqlen, chunk_size = seq_len_chunk_size

A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
d_head, itype)

Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)

Y, final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True)

# just test the last in sequence
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)

# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.allclose(final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=1e-3,
rtol=1e-3)
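# a single test can be run in isolation with, e.g.,
#   pytest tests/kernels/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_single_example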


@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[

# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths

# odd chunk_size
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
(21, 15)]), # irregular sizes

# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):

# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)

seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: Dict = {} # map: eg -> pointer to last taken sample
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted

states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
C) in generate_continous_batched_examples(
cases, num_examples, seqlen,
last_taken, exhausted, n_heads,
d_head, itype):

Y, new_states = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=sed_idx,
return_varlen_states=True,
initial_states=states,
)

# just test the last in sequence
for i in range(num_examples):

# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)

# update states
states = new_states
for i, clear in exhausted.items():
if clear:
states[i].fill_(0.)
exhausted[i] = False
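        # once an example is exhausted it restarts from the beginning on the
        # next continuous batch, so its carried-over state is cleared to match
        # a fresh zero initial state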