diff --git a/tests/kernels/test_mamba_mixer2.py b/tests/kernels/test_mamba_mixer2.py new file mode 100644 index 0000000000000..8c441fcbe61e2 --- /dev/null +++ b/tests/kernels/test_mamba_mixer2.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Tuple + +import pytest +import torch + +from tests.utils import multi_gpu_test +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize( + "hidden_size_n_groups", + [ + (64, 1), + (64, 2), + (64, 4), # hidden_size be divisible by num_gpus + (100, 5), # and n_groups must divide hidden_size + ]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_mixer2_gated_norm_multi_gpu( + batch_size: int, + seq_len: int, + hidden_size_n_groups: Tuple[int, int], + dtype: torch.dtype, + device: str = 'cuda', +): + hidden_size, n_groups = hidden_size_n_groups + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs) + + run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) + + +def mixer2_gated_norm_tensor_parallel( + local_rank: int, + world_size: int, + batch_size: int, + seq_len: int, + hidden_size: int, + n_groups: int, + dtype: torch.dtype, + device: str, +): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # create random weights an inputs + weight = torch.rand((hidden_size, ), dtype=dtype, device=device) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + gate_states = torch.randn(batch_size, seq_len, hidden_size) + + # create gated-norm with TP + mixer = Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + mixer.weight.weight_loader(mixer.weight, weight) # load + + # create gated-norm without TP to compute reference + # - utilize mock patching to disable TP when + with (unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." + "get_tensor_model_parallel_world_size", + return_value=1), + unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." 
+ "get_tensor_model_parallel_rank", + return_value=0)): + mixer_single_gpu = Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + # assign weight to single-gpu mixer + mixer_single_gpu.weight.data = weight + + # generate and compare + N = hidden_size // world_size + output = mixer( + hidden_states[..., local_rank * N:(local_rank + 1) * N], + gate_states[..., local_rank * N:(local_rank + 1) * N], + ) + ref_output = mixer_single_gpu(hidden_states, gate_states) + torch.allclose(output, + ref_output[..., local_rank * N:(local_rank + 1) * N], + atol=1e-3, + rtol=1e-3) diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py new file mode 100644 index 0000000000000..882513116ed6d --- /dev/null +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Tuple + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.platforms import current_platform + +# Added by the IBM Team, 2024 + +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py + + +# this is the segsum implementation taken from above +def segsum(x): + """Calculates segment sum.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) + for x in (X, A, B, C)) + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at + # chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + # (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +def generate_random_inputs(batch_size, + seqlen, + n_heads, + d_head, + itype, + device='cuda'): + + current_platform.seed_everything(0) + A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + dt = F.softplus( + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - + 4) + X = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + + return A, dt, X, B, C + + +def generate_continous_batched_examples(example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device='cuda'): + + # this function generates a random examples of certain length + # and then cut according to "example_lens_by_batch" and feed + # them in continuous batches to the kernels + + # generate the full-length example + A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // 4) + + # internal function that outputs a cont batch of examples + # given a tuple of lengths for each example in the batch + # e.g., example_lens=(8, 4) means take 8 samples from first eg, + # 4 examples from second eg, etc + def get_continuous_batch(example_lens: Tuple[int, ...]): + + indices = [] + for i, x in enumerate(example_lens): + c = last_taken.get(i, 0) + indices.append((c, c + x)) + last_taken[i] = (c + x) % full_length + exhausted[i] = last_taken[i] == 0 + + return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) + ]).unsqueeze(0) for x in (dt, X, B, C)) + + # internal function that maps "n" to the appropriate right boundary + # value when forming continuous batches from examples of length given + # by "full_length". 
+ # - e.g., when n > full_length, returns n % full_length + # when n == full_length, returns full_length + def end_boundary(n: int): + return n - ((n - 1) // full_length) * full_length + + IND_E = None + for spec in example_lens_by_batch: + + # get the (maybe partial) example seen in this cont batch + dt2, X2, B2, C2 = get_continuous_batch(spec) + + # get the metadata + cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) + sed_idx = torch.zeros(cu_seqlens[-1], + dtype=torch.int32, + device=cu_seqlens.device) + for i, (srt, end) in enumerate(zip( + cu_seqlens, + cu_seqlens[1:], + )): + sed_idx[srt:end] = i + + # for cont batch + if IND_E is None: + IND_S = [0 for _ in range(len(spec))] + else: + IND_S = [x % full_length for x in IND_E] + IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + + yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)]) +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, + itype): + + # this tests the kernels on a single example (no batching) + + # set seed + batch_size = 1 # batch_size + # ssd_minimal_discrete requires chunk_size divide seqlen + # - this is only required for generating the reference seqs, + # it is not an operational limitation. + seqlen, chunk_size = seq_len_chunk_size + + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, + B, C, chunk_size) + + Y, final_state = mamba_chunk_scan_combined(X, + dt, + A, + B, + C, + chunk_size, + D=None, + return_final_states=True) + + # just test the last in sequence + torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) + + # just test the last head + # NOTE, in the kernel we always cast states to fp32 + torch.allclose(final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=1e-3, + rtol=1e-3) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("n_heads", [4, 8, 13]) +@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize( + "seq_len_chunk_size_cases", + [ + + # small-ish chunk_size (8) + (64, 8, 2, [(64, 32), (64, 32)]), + (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), + (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary + (64, 8, 2, [(4, 4), (4, 4), (4, 4), + (4, 4)]), # chunk_size larger than cont batches + (64, 8, 5, [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ]), # mode examples with varied lengths + + # odd chunk_size + (64, 29, 2, [(11, 4), (13, 23), (19, 22), + (21, 15)]), # irregular sizes + + # large-ish chunk_size (256) + (64, 256, 1, [(5, ), (1, ), (1, ), + (1, )]), # irregular sizes with small sequences + (64, 256, 2, [(5, 30), (1, 2), (1, 2), + (1, 2)]), # irregular sizes with small sequences + ]) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, + itype): + + # this test with multiple examples in a continuous batch + # (i.e. 
chunked prefill) + + seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: Dict = {} # map: eg -> pointer to last taken sample + exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted + + states = None + for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + C) in generate_continous_batched_examples( + cases, num_examples, seqlen, + last_taken, exhausted, n_heads, + d_head, itype): + + Y, new_states = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=sed_idx, + return_varlen_states=True, + initial_states=states, + ) + + # just test the last in sequence + for i in range(num_examples): + + # just test one dim and dstate + Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_min_eg = Y_min[i][:, 0, 0] + torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) + + # update states + states = new_states + for i, clear in exhausted.items(): + if clear: + states[i].fill_(0.) + exhausted[i] = False diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_hybrid.py similarity index 91% rename from tests/models/decoder_only/language/test_jamba.py rename to tests/models/decoder_only/language/test_hybrid.py index cc98f1d7b5ce8..a39b11923582c 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -8,7 +8,8 @@ from ...utils import check_outputs_equal -MODELS = ["ai21labs/Jamba-tiny-dev"] +# This test is for the hybrid models +MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] @pytest.mark.parametrize("model", MODELS) @@ -23,6 +24,10 @@ def test_models( max_tokens: int, ) -> None: + # numeric error produces different generation + if 'Bamba' in model: + example_prompts.pop(3) + with hf_runner( model, dtype=dtype, @@ -108,15 +113,21 @@ def test_mamba_prefill_chunking_with_parallel_sampling( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize("max_tokens", [7]) def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int) -> None: # numeric error during prefill chucking produces different generation # compared to w/o prefill chunking for those examples, removed them for now - example_prompts.pop(7) - example_prompts.pop(2) - example_prompts.pop(1) + if 'Jamba' in model: + example_prompts.pop(7) + example_prompts.pop(2) + example_prompts.pop(1) + elif 'Bamba' in model: + example_prompts.pop(6) + example_prompts.pop(3) + example_prompts.pop(2) + dtype = "half" # use a different dtype for Bamba with hf_runner( model, @@ -145,7 +156,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [15]) def test_parallel_sampling( vllm_runner, @@ -249,17 +260,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba inner state management doesn't + # This test is for verifying that the hybrid inner state management doesn't # collapse in case where the number of incoming requests and # finished_requests_ids is larger than the 
maximum mamba block capacity. - # This could generally happen due to the fact that Jamba does support + # This could generally happen due to the fact that hybrid does support # statelessness mechanism where it can cleanup new incoming requests in # a single step. try: with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up properly between" + pytest.fail("Hybrid inner state wasn't cleaned up properly between" "steps finished requests registered unnecessarily ") @@ -271,14 +282,14 @@ def test_state_cleanup( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba state is cleaned up between + # This test is for verifying that the Hybrid state is cleaned up between # steps, If its not cleaned, an error would be expected. try: with vllm_runner(model, dtype=dtype) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up between states, " + pytest.fail("Hybrid inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") @@ -324,7 +335,7 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) -def test_jamba_distributed_produces_identical_generation( +def test_hybrid_distributed_produces_identical_generation( vllm_runner, model: str, dtype: str, max_tokens: int, example_prompts) -> None: diff --git a/tests/models/registry.py b/tests/models/registry.py index 20787fe008aa8..3fd94b89c8a60 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -102,6 +102,7 @@ def check_available_online( trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), # ChatGLMModel supports multimodal "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 9f6e731afd193..f363ba0c1e30c 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -2,6 +2,7 @@ from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -15,6 +16,7 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. @@ -77,43 +79,39 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). 
The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. + block_tables: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @@ -125,11 +123,17 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.seq_start_loc is not None + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) # Placeholders slot_mapping = torch.empty(0) @@ -143,15 +147,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -169,6 +173,8 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, @@ -178,13 +184,16 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, @@ -235,8 +244,6 @@ def advance_step(self, assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) - assert self.block_tables is not None - # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): @@ -299,9 +306,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -323,15 +327,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap" - " (i.e., Gemma-2). Otherwise, the output might be wrong." 
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: @@ -341,48 +336,37 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: - num_decode_tokens = batch_size - + num_decode_tokens = batch_size - self.num_prefill_tokens assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) # Placeholders - slot_mapping = torch.empty(0) + slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, @@ -393,8 +377,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 0000000000000..5fd1264910231 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,534 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) +from vllm.attention.backends.xformers import XFormersMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + 
get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + +# Added by the IBM Team, 2024 + + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + + def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + assert self.full_hidden_size % self.tp_size== 0,\ + "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. 
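+        # Illustrative example (hypothetical values, not taken from any real
+        # config): full_hidden_size=64, n_groups=4, tp_size=2 gives
+        # group_size=16 and 32 channels per rank, i.e. exactly two whole
+        # groups per rank, so case 2 applies and the RMS statistics can be
+        # computed locally. With n_groups=3 instead, tp_size does not divide
+        # n_groups, so case 3 applies and the input is all-gathered first.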
+ input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if self.tp_size > 1 or self.n_groups != 1: + return self.forward_native(x, gate) + + from vllm import _custom_ops as ops + + # cast x and gate to float32 before silu + out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) + ops.rms_norm( + out, + y.to(x.dtype), + self.weight.data, + self.variance_epsilon, + ) + return out + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + return tp_size - ngroups % tp_size + + +def mamba_v2_sharded_weight_loader( + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, ratio in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. + # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + rank = tp_rank // ratio + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. 
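+            # Illustrative case (hypothetical sizes): with n_groups=1 and
+            # tp_size=2 the groups get padded to 2, so this spec has
+            # full_dim=2*dstate, extra=dstate and a large ratio; every
+            # tp_rank then maps to loaded rank 0 and copies the same
+            # dstate-sized slice, replicating the single group so it
+            # follows its head shard.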
+ take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[ + boundary:(boundary + take), # type: ignore[misc] + ...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += (full_dim - extra) + + return loader + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation="silu", + chunk_size: int = 256, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert num_heads % self.tp_size == 0, \ + "Tensor parallel world size must divide num heads." + + self.ssm_state_size = ssm_state_size + self.activation = activation + + self.chunk_size = chunk_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) + + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
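+        # (Hypothetical shapes: with conv_kernel_size=4 the ColumnParallelLinear
+        # weight is (conv_dim_per_rank, 4); unsqueeze(1) turns it into
+        # (conv_dim_per_rank, 1, 4), the depthwise-Conv1d layout checkpoints
+        # use, which forward_cuda later views back as (dim, width).)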
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + self.num_heads // + n_groups, # ratio for mapping back to original group + ) + intermediate_settings = (intermediate_size, 0, 1) + head_setings = (self.num_heads, 0, 1) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) + + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) + + self.norm = Mixer2RMSNormGated(intermediate_size, + n_groups, + eps=rms_norm_eps) + + def forward_native(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + ): + + seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + # detect if there are prefills + has_prefill = attn_metadata.num_prefills > 0 + + # - also need flags to indicate if there are initial states + # - currently we really only support the 
FlashAttention backend + has_initial_states = None + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata)) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = attn_metadata.context_lens_tensor > 0 + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if has_prefill: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C = causal_conv1d_fn( + hidden_states_B_C.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=has_initial_states, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] + + # TODO: Why is this needed? + hidden_states_B_C = hidden_states_B_C.contiguous() + else: + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + + # - get hidden_states, B and C after depthwise convolution. + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + + # 3. 
State Space Model sequence transformation + if has_prefill: + + initial_states = None + if has_initial_states is not None and any(has_initial_states): + for idx in mamba_cache_params.state_indices_tensor[ + ~has_initial_states]: + mamba_cache_params.ssm_state[idx].zero_() + initial_states = mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, seq_len, self.n_groups // self.tp_size, -1), + C.view(1, seq_len, self.n_groups // self.tp_size, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=sequence_idx, + cu_seqlens=attn_metadata.query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + for i, idx in enumerate(mamba_cache_params.state_indices_tensor): + mamba_cache_params.ssm_state[idx].copy_(varlen_state[i]) + + # - reshape + hidden_states = scan_output.view(seq_len, -1) + else: + + n_groups = self.n_groups // self.tp_size + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into number of current batches + # - in this case there is no more prefill, so the batches gen + # 1 token at a time + # - thus hidden will be (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using "mamba_cache_params.state_indices_tensor", just as + # above in the prefill case + + hidden_states = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor, + ) + hidden_states = hidden_states.view( + -1, (self.num_heads // self.tp_size) * self.head_dim) + + # # 4. gated MLP + hidden_states = self.norm(hidden_states, gate) + + # # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 3c35f1ac0dcf5..b31b980fbe84a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2024, Tri Dao, Albert Gu. 
-# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py import torch import triton diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 0000000000000..388a63327213b --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = 
b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + + +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. 
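+    # (Illustrative, hypothetical sizes: a of shape (1, 119, 2, 32) with
+    # chunk_size=17 gives nchunks=7 and out of shape (1, 7, 2, 17, 17),
+    # i.e. one (chunk_size, chunk_size) matrix of a.b inner products per
+    # group and chunk.)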
+ out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 0000000000000..722fbd714ca8f --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,615 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, 
+ chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=c_idx >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care 
of + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. + + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1, + other=0.0).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. 
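+ # - dA_cs_m is accumulated from the start of the physical chunk, so
+ # subtracting dA_cs_m_boundary re-bases the cumulative decay at the
+ # adjusted (pseudo) chunk boundary before it is applied to the
+ # carried-over prev_states / init_states contribution below.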
+ scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. 
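+ # - cb holds the C_m^T B_k tile from the bmm step; scaling it by
+ # exp(dA_cs_m - dA_cs_k) and by dt_k (below) builds the per-pair decay
+ # factor inside the chunk, the IS_CAUSAL mask then zeroes k > m, and
+ # the tl.dot with x accumulates the intra-chunk (diagonal-block) term.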
+ cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + +def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): + + # convert seq_idx to chunk indices and offsets + # - derive the cu_seqlens + _, cu_seqlens = torch.where(seq_idx.diff()) + cu_seqlens += 1 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device) + chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device) + + cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]] + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + _s, _e = s // chunk_size + p, e // chunk_size + p + 1 + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + initial_states=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, 
ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + chunk_indices, chunk_offsets = None, None + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert initial_states.shape == (seq_idx[0].max() + 1, nheads, + headdim, dstate) + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + + # Allocates output. + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 0000000000000..a970ac94580b4 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,750 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch +import triton +import triton.language as tl + +from .mamba_ssm import softplus + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 
'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < 
dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = 
tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + # If HAS_INITSTATES==True need to consider two possiblties + # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if ((start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES)): + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + + # - this seems repetitve, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + 
else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads, ) + if dt_bias is not None: + assert dt_bias.shape == (nheads, ) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + 
states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) + return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 0000000000000..97cdb70b63cc6 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py + +# ruff: noqa: E501 + +import torch +import triton +from einops import rearrange +from packaging import version + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads, ) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, + headdim, dstate) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx and iii) has_cu_seqlens to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. 
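+ # - in effect, per (batch, head), this computes roughly the recurrence
+ # new_state[c] = exp(dA_cumsum[..., c, -1]) * new_state[c - 1] + states[c]
+ # (with the init-state substitution / zeroing described above when the
+ # sequence index changes); final_states is the state after the last chunk.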
+ states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype, + is_cont_batched=cu_seqlens is not None) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + out, out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + initial_states=initial_states, + ) + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py new file mode 100644 index 0000000000000..d8f87c113f168 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -0,0 +1,207 @@ 
+# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py + +# ruff: noqa: E501 + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_new below. + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) + if HAS_INITSTATES: + if IS_CONT_BATCHED and seq_idx != seq_idx_new: + # this means in the current chunk the rightmost flushed seq + # has changed. 
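+ # (i.e. at least one new example starts within this chunk)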
+ # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) + else: + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, +): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, "" + assert initial_states.shape == (seq_idx.max().item() + 1, nheads, + dim) + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py new file mode 100644 index 0000000000000..72b74e31b6cc8 --- /dev/null +++ b/vllm/model_executor/models/bamba.py @@ -0,0 +1,592 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only Bamba model.""" +# Added by the IBM Team, 2024 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import BambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from 
vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BambaMLP(nn.Module): + + def __init__( + self, + config: BambaConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BambaMixerDecoderLayer(nn.Module): + + def __init__(self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + chunk_size=config.mamba_chunk_size, + quant_config=quant_config) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params, sequence_idx) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class BambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": BambaAttentionDecoderLayer, + "mamba": BambaMixerDecoderLayer +} + + +class BambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + 
org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + if isinstance(layer, BambaAttentionDecoderLayer): + kv_cache = kv_caches[num_attn] + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, BambaMixerDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + sequence_idx=seq_idx, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Bamba 
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        lora_config = vllm_config.lora_config
+        scheduler_config = vllm_config.scheduler_config
+        assert not cache_config.enable_prefix_caching, \
+            "Bamba currently does not support prefix caching"
+
+        self.quant_config = vllm_config.quant_config
+
+        super().__init__()
+        self.config = config
+        self.scheduler_config = scheduler_config
+        self.model = BambaModel(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        # Used to track and store the Mamba cache between steps.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+        # follow jamba
+        if self.scheduler_config is not None and \
+                not self.model_config.enforce_eager:
+            # for compilation
+            if self.scheduler_config.max_num_seqs > \
+                    vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        elif self.scheduler_config is not None:
+            # for eager, just take the scheduler config's value if available
+            self.max_batch_size = self.scheduler_config.max_num_seqs
+        else:
+            self.max_batch_size = 8192 + 2
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[KVCache],
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                **kwargs):
+        if self.mamba_cache is None:
+
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
+
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, mamba_cache_params,
+                                   intermediate_tensors, inputs_embeds)
+
+        return hidden_states
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
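# Illustrative sketch (not part of the patch): __init__ above sizes the Mamba
# cache by the largest batch a CUDA graph may capture. max_num_seqs is padded
# up to a capture size but clamped at the largest configured capture size. The
# helper and numbers below are assumptions standing in for
# vllm_config.pad_for_cudagraph and the compilation config.
def pick_max_batch_size(max_num_seqs: int, capture_sizes: list) -> int:
    max_capture_size = max(capture_sizes)
    if max_num_seqs > max_capture_size:
        return max_capture_size
    # round up to the next captured batch size (assumed pad_for_cudagraph rule)
    return min(size for size in capture_sizes if size >= max_num_seqs)

capture_sizes = [1, 2, 4, 8, 16, 32]
assert pick_max_batch_size(5, capture_sizes) == 8
assert pick_max_batch_size(100, capture_sizes) == 32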
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+        hidden_size = self.config.hidden_size
+
+        conv_state_shape, temporal_state_shape = None, None
+
+        intermediate_size = self.config.mamba_expand * hidden_size
+
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head are sharded along with it
+        n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
+            self.config.mamba_n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size +
+                    2 * n_groups * self.config.mamba_d_state)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            self.config.mamba_d_conv - 1,
+        )
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        #   e.g., (n_heads, d_head, d_state) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(self.config.mamba_n_heads, world_size),
+            self.config.mamba_d_head,
+            self.config.mamba_d_state,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
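# Illustrative sketch (not part of the patch): _get_mamba_cache_shape above
# sizes the per-request conv and SSM states after tensor-parallel sharding.
# The config values below (and the simplified stand-in for
# extra_groups_for_head_shards, written for the n_groups=1 case assumed here)
# are illustrative only.
def extra_groups(ngroups: int, tp_size: int) -> int:
    # replicate groups so every head shard still sees a whole group
    return 0 if ngroups % tp_size == 0 else tp_size - ngroups

hidden_size, expand, n_groups, d_state, d_conv, n_heads, d_head, tp = \
    4096, 2, 1, 128, 4, 128, 64, 2

intermediate_size = expand * hidden_size                   # 8192
groups = n_groups + extra_groups(n_groups, tp)             # 1 -> 2, one/rank
conv_dim = intermediate_size + 2 * groups * d_state        # 8704
conv_state_shape = (conv_dim // tp, d_conv - 1)            # (4352, 3)
temporal_state_shape = (n_heads // tp, d_head, d_state)    # (64, 64, 128)
assert conv_state_shape == (4352, 3)
assert temporal_state_shape == (64, 64, 128)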
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index d82c0815213bc..f307f279dad4d 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -455,14 +455,9 @@ def forward(self, self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 5034b334564e8..3bbc219e92a65 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -232,15 +232,7 @@ def forward(self, self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, attn_metadata, mamba_cache_params, intermediate_tensors, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 353177f784b2e..ce4197507da03 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -5,7 +5,6 @@ import torch -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.utils import PAD_SLOT_ID @@ -42,8 +41,7 @@ def __init__(self, dtype, num_mamba_layers, max_batch_size, self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.free_cache_indices = list(range(max_batch_size)) - def current_run_tensors(self, input_ids: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs): + def current_run_tensors(self, **kwargs) -> MambaCacheParams: """ Return the tensors for the current run's conv and ssm state. 
""" @@ -66,7 +64,8 @@ def current_run_tensors(self, input_ids: torch.Tensor, (mamba_cache_tensors, state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return (mamba_cache_tensors, state_indices_tensor) + return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], + state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069efc91..c2d0fae7056c7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -37,6 +37,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), # ChatGLMModel supports multimodal "CohereForCausalLM": ("commandr", "CohereForCausalLM"),