From 9f75b2bb500c462875d2c51ea8b7998f37b8e49b Mon Sep 17 00:00:00 2001 From: Norman Mu Date: Wed, 20 Dec 2023 07:36:25 +0000 Subject: [PATCH] Add support for left padding and masking in forward() and generate() --- mamba_ssm/models/config_mamba.py | 1 + mamba_ssm/models/mixer_seq_simple.py | 18 ++++--- mamba_ssm/modules/mamba_simple.py | 14 +++-- mamba_ssm/ops/selective_scan_interface.py | 10 ++-- mamba_ssm/utils/generation.py | 16 ++++-- tests/test_padding.py | 62 +++++++++++++++++++++++ 6 files changed, 103 insertions(+), 18 deletions(-) create mode 100644 tests/test_padding.py diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py index ffd31abc..97813f80 100644 --- a/mamba_ssm/models/config_mamba.py +++ b/mamba_ssm/models/config_mamba.py @@ -12,3 +12,4 @@ class MambaConfig: residual_in_fp32: bool = True fused_add_norm: bool = True pad_vocab_size_multiple: int = 8 + use_fast_path: bool = True diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 5b3ddfcf..9cd1cbab 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -31,11 +31,12 @@ def create_block( layer_idx=None, device=None, dtype=None, + use_fast_path=True, ): if ssm_cfg is None: ssm_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} - mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + mixer_cls = partial(Mamba, layer_idx=layer_idx, use_fast_path=use_fast_path, **ssm_cfg, **factory_kwargs) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) @@ -97,6 +98,7 @@ def __init__( residual_in_fp32=False, device=None, dtype=None, + use_fast_path=True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -124,6 +126,7 @@ def __init__( residual_in_fp32=residual_in_fp32, fused_add_norm=fused_add_norm, layer_idx=i, + use_fast_path=use_fast_path, **factory_kwargs, ) for i in range(n_layer) @@ -148,12 +151,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, inference_params=None): + def forward(self, input_ids, mask=None, inference_params=None): hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: hidden_states, residual = layer( - hidden_states, residual, inference_params=inference_params + hidden_states, residual, mask=mask, inference_params=inference_params ) if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states @@ -205,6 +208,7 @@ def __init__( initializer_cfg=initializer_cfg, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, + use_fast_path=config.use_fast_path, **factory_kwargs, ) self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) @@ -225,12 +229,12 @@ def tie_weights(self): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + def forward(self, input_ids, attention_mask=None, position_ids=None, inference_params=None, num_last_tokens=0): """ "position_ids" is just to be compatible with Transformer generation. We don't use it. num_last_tokens: if > 0, only return the logits for the last n tokens """ - hidden_states = self.backbone(input_ids, inference_params=inference_params) + hidden_states = self.backbone(input_ids, mask=attention_mask, inference_params=inference_params) if num_last_tokens > 0: hidden_states = hidden_states[:, -num_last_tokens:] lm_logits = self.lm_head(hidden_states) @@ -240,8 +244,8 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_ @classmethod def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): config_data = load_config_hf(pretrained_model_name) - config = MambaConfig(**config_data) - model = cls(config, device=device, dtype=dtype, **kwargs) + config = MambaConfig(**config_data, **kwargs) + model = cls(config, device=device, dtype=dtype) model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) return model diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 98d97a57..f1968195 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -116,7 +116,7 @@ def __init__( self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, mask=None, inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states @@ -156,10 +156,15 @@ def forward(self, hidden_states, inference_params=None): None, # input-dependent C self.D.float(), delta_bias=self.dt_proj.bias.float(), + mask=mask, delta_softplus=True, ) else: x, z = xz.chunk(2, dim=1) + + if mask is not None: + x = x * mask.unsqueeze(1) + # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv @@ -176,6 +181,9 @@ def forward(self, hidden_states, inference_params=None): activation=self.activation, ) + if mask is not None: + x = x * mask.unsqueeze(1) + # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. @@ -322,7 +330,7 @@ def __init__( ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + self, hidden_states: Tensor, residual: Optional[Tensor] = None, mask: Optional[Tensor] = None, inference_params=None ): r"""Pass the input through the encoder layer. @@ -346,7 +354,7 @@ def forward( residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) + hidden_states = self.mixer(hidden_states, mask=mask, inference_params=inference_params) return hidden_states, residual def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index b8f14dd0..35143adf 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ @@ -177,6 +177,8 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) + if mask is not None: + x = x * mask.unsqueeze(1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) # We're being very careful here about the layout, to avoid extra transposes. @@ -214,6 +216,8 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh C = C.contiguous() if D is not None: D = D.contiguous() + if mask is not None: + conv1d_out = conv1d_out * mask.unsqueeze(1) out, scan_intermediates, out_z = selective_scan_cuda.fwd( conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) @@ -301,11 +305,11 @@ def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, mask=None, delta_softplus=True ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, mask, delta_softplus) def mamba_inner_ref( diff --git a/mamba_ssm/utils/generation.py b/mamba_ssm/utils/generation.py index 460ff7b1..3b3593f9 100644 --- a/mamba_ssm/utils/generation.py +++ b/mamba_ssm/utils/generation.py @@ -108,6 +108,7 @@ def decode( input_ids, model, max_length, + attention_mask=None, top_k=1, top_p=0.0, temperature=1.0, @@ -171,10 +172,11 @@ def get_logits(input_ids, inference_params): position_ids=position_ids, inference_params=inference_params, num_last_tokens=1, + attention_mask=attention_mask, # mask is not used in incremental step() calls, so don't update ).logits.squeeze(dim=1) else: logits = model._decoding_cache.run( - input_ids, position_ids, inference_params.seqlen_offset + input_ids, attention_mask, position_ids, inference_params.seqlen_offset ).squeeze(dim=1) return logits[..., :vocab_size] if vocab_size is not None else logits @@ -234,6 +236,7 @@ def generate( self, input_ids, max_length, + attention_mask=None, top_k=1, top_p=0.0, temperature=1.0, @@ -242,7 +245,7 @@ def generate( **kwargs, ): output = decode( - input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs + input_ids, self, max_length, attention_mask=attention_mask, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs ) if not output_scores: output.scores = None @@ -312,9 +315,9 @@ def update_graph_cache( n_warmups=n_warmups, ) - def dispatch(input_ids, position_ids, seqlen): + def dispatch(input_ids, attention_mask, position_ids, seqlen): batch_size, decoding_seqlen = input_ids.shape[:2] - return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) + return cache.callables[batch_size, decoding_seqlen](input_ids, attention_mask, position_ids, seqlen) cache.run = dispatch cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing @@ -326,6 +329,7 @@ def capture_graph( ): device = next(iter(model.parameters())).device input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + attention_mask = torch.full((batch_size, decoding_seqlen), 1, dtype=torch.long, device=device) position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) seqlen_offset_og = inference_params.seqlen_offset inference_params.seqlen_offset = max_seqlen - decoding_seqlen @@ -338,6 +342,7 @@ def capture_graph( for _ in range(n_warmups): logits = model( input_ids, + attention_mask=attention_mask, position_ids=position_ids, inference_params=inference_params, num_last_tokens=decoding_seqlen, @@ -355,12 +360,13 @@ def capture_graph( with torch.cuda.graph(graph, pool=mempool): logits = model( input_ids, + attention_mask=attention_mask, position_ids=position_ids, inference_params=inference_params, num_last_tokens=decoding_seqlen, ).logits - def run(new_input_ids, new_position_ids, seqlen): + def run(new_input_ids, attention_mask, new_position_ids, seqlen): inference_params.lengths_per_sample[:] = seqlen input_ids.copy_(new_input_ids) position_ids.copy_(new_position_ids) diff --git a/tests/test_padding.py b/tests/test_padding.py new file mode 100644 index 00000000..60d71857 --- /dev/null +++ b/tests/test_padding.py @@ -0,0 +1,62 @@ +import torch +from transformers import AutoTokenizer + +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + +model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-1.4b', use_fast_path=True).to('cuda') +tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') +tokenizer.padding_side = 'left' +tokenizer.pad_token = tokenizer.eos_token + +pad_count = 10 + +# Check prefill logits +input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda') +input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0] * pad_count]), input_ids], dim=1) +attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0] * pad_count]), torch.ones_like(input_ids)], dim=1) + +out = model(input_ids_padded).logits.detach().cpu() +out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu() +out_true = model(input_ids).logits.detach().cpu() + +print("max L2 error:", (out_true - out[:, pad_count:]).norm(dim=-1).max()) +print("max L2 errors (padded):", (out_true - out_padded[:, pad_count:]).norm(dim=-1).max()) + + +# Check decoding outputs +text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit.' + +print("\n\nNo CUDA graph:") +inputs = tokenizer([text], return_tensors='pt').to('cuda') +x = model.generate(inputs.input_ids, max_length=100, temperature=0, cg=False) +print("\nNo pad, no mask:") +print(tokenizer.decode(x[0], skip_special_tokens=True)) + +inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') +x = model.generate(inputs.input_ids, max_length=100 + pad_count, temperature=0, cg=False) +print("\nPad, no mask:") +print(tokenizer.decode(x[0], skip_special_tokens=True)) + +inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') +inputs.attention_mask[:, :pad_count] = 0 +x = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100 + pad_count, temperature=0, cg=False) +print("\nPad, mask:") +print(tokenizer.decode(x[0], skip_special_tokens=True)) + +print("\n\nCUDA graph:") +inputs = tokenizer([text], return_tensors='pt').to('cuda') +x = model.generate(inputs.input_ids, max_length=100, temperature=0, cg=True) +print("\nNo pad, no mask:") +print(tokenizer.decode(x[0], skip_special_tokens=True)) + +inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') +x = model.generate(inputs.input_ids, max_length=100 + pad_count, temperature=0, cg=True) +print("\nPad, no mask:") +print(tokenizer.decode(x[0], skip_special_tokens=True)) + +inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') +inputs.attention_mask[:, :pad_count] = 0 +x = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100 + pad_count, temperature=0, cg=True) +print("\nPad, mask:") +print(tokenizer.decode(x[0], skip_special_tokens=True))