First of all, thanks for this amazing library!

This is likely not the intended use case of the library, but I noticed (by chance) that the behavior of `mamba_split_conv1d_scan_combined` seems to differ from its reference(?) Python implementation `mamba_split_conv1d_scan_ref` when using `ngroups > 1`.
```python
import torch

from mamba_ssm.ops.triton.ssd_combined import (
    mamba_split_conv1d_scan_combined,
    mamba_split_conv1d_scan_ref,
)


def f(ngroups: int) -> None:
    batch_size = 2
    seqlen = 2048
    dim = 1024 * 2
    headdim = 64
    dstate = 128
    nheads = dim // headdim
    conv_dim = 4
    chunk_size = 128
    dtype = torch.float32
    device = "cuda"

    zxbcdt = torch.ones(batch_size, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device=device, dtype=dtype) / 10.
    conv1d_weight = torch.ones(dim + 2 * ngroups * dstate, conv_dim, device=device, dtype=dtype) / 10.
    conv1d_bias = torch.ones(dim + 2 * ngroups * dstate, device=device, dtype=dtype) / 10.
    dt_bias = torch.ones(nheads, device=device, dtype=dtype) / 10.
    A = -torch.exp(torch.zeros(nheads, device=device, dtype=dtype))
    rmsnorm_weight = torch.ones(dim, device=device, dtype=dtype) / 10.
    rmsnorm_eps = 1e-5
    outproj_weight = torch.ones(dim, dim, device=device, dtype=dtype) / 10.
    outproj_bias = torch.ones(dim, device=device, dtype=dtype) / 10.
    # Make one argument differ across heads (similar effects can be observed by varying other arguments too).
    D = torch.arange(nheads, device=device, dtype=dtype) * 1000.

    # Fused Triton kernel.
    out_combined = mamba_split_conv1d_scan_combined(
        zxbcdt=zxbcdt,
        conv1d_weight=conv1d_weight,
        conv1d_bias=conv1d_bias,
        dt_bias=dt_bias,
        A=A,
        D=D,
        chunk_size=chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=rmsnorm_weight,
        rmsnorm_eps=rmsnorm_eps,
        outproj_weight=outproj_weight,
        outproj_bias=outproj_bias,
        headdim=headdim,
        ngroups=ngroups,
        norm_before_gate=False,
    )
    # Reference Python implementation.
    out_ref = mamba_split_conv1d_scan_ref(
        zxbcdt=zxbcdt,
        conv1d_weight=conv1d_weight,
        conv1d_bias=conv1d_bias,
        dt_bias=dt_bias,
        A=A,
        D=D,
        chunk_size=chunk_size,
        # initial_states=None,
        # seq_idx=None,
        dt_limit=(0.0, float("inf")),
        # return_final_states=False,
        activation="silu",
        rmsnorm_weight=rmsnorm_weight,
        rmsnorm_eps=rmsnorm_eps,
        outproj_weight=outproj_weight,
        outproj_bias=outproj_bias,
        headdim=headdim,
        ngroups=ngroups,
        norm_before_gate=False,
    )

    print("mamba_split_conv1d_scan_combined")
    print(out_combined[0, 0, 0])
    print()
    print("mamba_split_conv1d_scan_ref")
    print(out_ref[0, 0, 0])


for ngroups in [1, 2, 4]:
    print(f"\n\n------- ngroups = {ngroups} -------")
    f(ngroups=ngroups)
```
The above code prints the following:
```
------- ngroups = 1 -------
mamba_split_conv1d_scan_combined
tensor(17.6950, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')


------- ngroups = 2 -------
mamba_split_conv1d_scan_combined
tensor(18.8725, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')


------- ngroups = 4 -------
mamba_split_conv1d_scan_combined
tensor(19.5926, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')
```
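For completeness, one could also compare the full output tensors rather than a single element. Here is a minimal sketch of a hypothetical helper, `report_mismatch`, that could be called at the end of `f` with the two results computed above (the tolerances are arbitrary):

```python
def report_mismatch(out_combined: torch.Tensor, out_ref: torch.Tensor) -> None:
    # Quantify the mismatch over the whole output instead of one element.
    max_abs_diff = (out_combined - out_ref).abs().max().item()
    close = torch.allclose(out_combined, out_ref, rtol=1e-3, atol=1e-3)
    print(f"allclose (rtol=atol=1e-3): {close}")
    print(f"max abs diff: {max_abs_diff:.6f}")
```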
Did I miss certain arguments, or is `ngroups > 1` simply not a supported feature? Thanks in advance!