diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 38b51c20b7eae..7cd0614309b85 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -4,9 +4,6 @@ from typing import Iterable, List, Optional, Tuple import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn from torch.nn.parameter import Parameter from transformers import MambaConfig @@ -21,6 +18,10 @@ MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler @@ -157,7 +158,7 @@ def mamba_forward(self, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) - hidden_states = causal_conv1d_fn( + hidden_states, _ = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias,