Skip to content

Commit

Permalink
migrate to tridao's native varlen causal_conv1d kernel for speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
zigzagcai committed Jul 22, 2024
1 parent 6961faa commit 5b024c5
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 167 deletions.
12 changes: 5 additions & 7 deletions mamba_ssm/models/mixer_seq_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.modules.mamba_simple import Block as Block_Mamba1
from mamba_ssm.modules.block import Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
Expand Down Expand Up @@ -71,8 +70,7 @@ def create_block(
mlp_cls = partial(
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
)
block_cls = Block if ssm_layer == "Mamba2" else Block_Mamba1
block = block_cls(
block = Block(
d_model,
mixer_cls,
mlp_cls,
Expand Down Expand Up @@ -189,12 +187,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
for i, layer in enumerate(self.layers)
}

def forward(self, input_ids, cu_seqlens=None, inference_params=None, **mixer_kwargs):
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
Expand Down Expand Up @@ -273,12 +271,12 @@ def tie_weights(self):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, cu_seqlens=None, **mixer_kwargs):
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs)
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/modules/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,4 @@ def forward(
return hidden_states, residual

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
101 changes: 21 additions & 80 deletions mamba_ssm/modules/mamba_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def __init__(
def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
"""
hidden_states: (B, L, D)
cu_seqlens: one-dimensional tensor representing cumulative start indexes of packed sequence, a.k.a., B=1
cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and end with L, and must already be sorted.
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
if cu_seqlens is not None:
assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1"

conv_state, ssm_state = None, None
if inference_params is not None:
Expand Down Expand Up @@ -158,46 +160,41 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
cu_seqlens=cu_seqlens,
d_conv=torch.tensor(self.d_conv)
cu_seqlens=cu_seqlens
)
else:
x, z = xz.chunk(2, dim=1)

# (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences
if cu_seqlens is not None:
padded_x = x
count = 0
for idx in cu_seqlens[1:-1].tolist():
padded_idx = idx + count*(self.d_conv - 1)
padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
count = count + 1
x = padded_x

# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
if cu_seqlens is not None:
# naive pure python implementation of varlen causal_conv1d
for i, s in enumerate(cu_seqlens[1:-1]):
x = torch.cat((x[..., :s + i*(self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i*(self.d_conv - 1):]), dim=2)
mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device),
torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0)
for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0)
x = self.act(self.conv1d(x)[:, :, mask])
else:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
if cu_seqlens is not None:
seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
else:
seq_idx = None
x = causal_conv1d_fn(
x=x,
x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
seq_idx=seq_idx,
activation=self.activation,
)

# (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
if cu_seqlens is not None:
mask = []
for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
mask.extend([True] * seq_len)
mask.extend([False] * (self.d_conv - 1))
mask = mask[:-(self.d_conv - 1)]
x = x[:, :, mask]

# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
Expand Down Expand Up @@ -317,59 +314,3 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state

class Block(nn.Module):
def __init__(
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
"""
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA/MLP -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Add -> LN -> Mixer, returning both
the hidden_states (output of the mixer) and the residual.
This is purely for performance reasons, as we can fuse add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(dim)
self.norm = norm_cls(dim)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None
):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual))
"""
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params)
return hidden_states, residual

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
Loading

0 comments on commit 5b024c5

Please sign in to comment.