-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for left padding and masking in forward() and generate() #70
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function): | |
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, | ||
out_proj_weight, out_proj_bias, | ||
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, | ||
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): | ||
C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1): | ||
""" | ||
xz: (batch, dim, seqlen) | ||
""" | ||
|
@@ -177,6 +177,8 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh | |
xz = xz.contiguous() | ||
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") | ||
x, z = xz.chunk(2, dim=1) | ||
if mask is not None: | ||
x = x * mask.unsqueeze(1) | ||
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None | ||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) | ||
# We're being very careful here about the layout, to avoid extra transposes. | ||
|
@@ -214,6 +216,8 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh | |
C = C.contiguous() | ||
if D is not None: | ||
D = D.contiguous() | ||
if mask is not None: | ||
conv1d_out = conv1d_out * mask.unsqueeze(1) | ||
out, scan_intermediates, out_z = selective_scan_cuda.fwd( | ||
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus | ||
) | ||
|
@@ -301,11 +305,11 @@ def mamba_inner_fn( | |
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, | ||
out_proj_weight, out_proj_bias, | ||
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, | ||
C_proj_bias=None, delta_softplus=True | ||
C_proj_bias=None, mask=None, delta_softplus=True | ||
): | ||
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have some confusion about this line of code. Since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be wrong, but the PR modified Line 162 in the very same file, which is the definition of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you! You are right. I might missed this line of code on L162. Err Msg:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xtwigs pointed out that one can add
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I have tried this approach but encountered with CUDA OOM, even with much more GPUs and much smaller |
||
out_proj_weight, out_proj_bias, | ||
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) | ||
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, mask, delta_softplus) | ||
|
||
|
||
def mamba_inner_ref( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import torch | ||
from transformers import AutoTokenizer | ||
|
||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel | ||
|
||
|
||
model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-1.4b', use_fast_path=True).to('cuda') | ||
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') | ||
tokenizer.padding_side = 'left' | ||
tokenizer.pad_token = tokenizer.eos_token | ||
|
||
pad_count = 10 | ||
|
||
# Check prefill logits | ||
input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda') | ||
input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0] * pad_count]), input_ids], dim=1) | ||
attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0] * pad_count]), torch.ones_like(input_ids)], dim=1) | ||
|
||
out = model(input_ids_padded).logits.detach().cpu() | ||
out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu() | ||
out_true = model(input_ids).logits.detach().cpu() | ||
|
||
print("max L2 error:", (out_true - out[:, pad_count:]).norm(dim=-1).max()) | ||
print("max L2 errors (padded):", (out_true - out_padded[:, pad_count:]).norm(dim=-1).max()) | ||
|
||
|
||
# Check decoding outputs | ||
text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit.' | ||
|
||
print("\n\nNo CUDA graph:") | ||
inputs = tokenizer([text], return_tensors='pt').to('cuda') | ||
x = model.generate(inputs.input_ids, max_length=100, temperature=0, cg=False) | ||
print("\nNo pad, no mask:") | ||
print(tokenizer.decode(x[0], skip_special_tokens=True)) | ||
|
||
inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') | ||
x = model.generate(inputs.input_ids, max_length=100 + pad_count, temperature=0, cg=False) | ||
print("\nPad, no mask:") | ||
print(tokenizer.decode(x[0], skip_special_tokens=True)) | ||
|
||
inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') | ||
inputs.attention_mask[:, :pad_count] = 0 | ||
x = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100 + pad_count, temperature=0, cg=False) | ||
print("\nPad, mask:") | ||
print(tokenizer.decode(x[0], skip_special_tokens=True)) | ||
|
||
print("\n\nCUDA graph:") | ||
inputs = tokenizer([text], return_tensors='pt').to('cuda') | ||
x = model.generate(inputs.input_ids, max_length=100, temperature=0, cg=True) | ||
print("\nNo pad, no mask:") | ||
print(tokenizer.decode(x[0], skip_special_tokens=True)) | ||
|
||
inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') | ||
x = model.generate(inputs.input_ids, max_length=100 + pad_count, temperature=0, cg=True) | ||
print("\nPad, no mask:") | ||
print(tokenizer.decode(x[0], skip_special_tokens=True)) | ||
|
||
inputs = tokenizer(['<|endoftext|>' * pad_count + text], return_tensors='pt').to('cuda') | ||
inputs.attention_mask[:, :pad_count] = 0 | ||
x = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100 + pad_count, temperature=0, cg=True) | ||
print("\nPad, mask:") | ||
print(tokenizer.decode(x[0], skip_special_tokens=True)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you make this change with the
kwargs
?