
mamba_split_conv1d_scan_combined and mamba_split_conv1d_scan_ref exhibit large numerical differences when ngroups > 1 #647

Open
HanGuo97 opened this issue Dec 15, 2024 · 0 comments


@HanGuo97

First of all, thanks for this amazing library!

This may not be an intended use case of the library, but I noticed (by chance) that mamba_split_conv1d_scan_combined behaves differently from what appears to be its reference Python implementation, mamba_split_conv1d_scan_ref, when ngroups > 1.

```python
import torch
from mamba_ssm.ops.triton.ssd_combined import (
    mamba_split_conv1d_scan_ref,
    mamba_split_conv1d_scan_combined)

def f(ngroups: int) -> None:
    batch_size = 2
    seqlen = 2048
    dim = 1024 * 2
    headdim = 64
    dstate = 128
    nheads = dim // headdim
    conv_dim = 4
    chunk_size = 128
    dtype = torch.float32
    device = "cuda"

    # The last dim of zxbcdt packs z and x (dim each), B and C (ngroups * dstate each), and dt (nheads)
    zxbcdt = torch.ones(batch_size, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device=device, dtype=dtype) / 10.
    conv1d_weight = torch.ones(dim + 2 * ngroups * dstate, conv_dim, device=device, dtype=dtype) / 10.
    conv1d_bias = torch.ones(dim + 2 * ngroups * dstate, device=device, dtype=dtype) / 10.
    dt_bias = torch.ones(nheads, device=device, dtype=dtype) / 10.
    A = -torch.exp(torch.zeros(nheads, device=device, dtype=dtype))
    rmsnorm_weight = torch.ones(dim, device=device, dtype=dtype) / 10.
    rmsnorm_eps = 1e-5
    outproj_weight = torch.ones(dim, dim, device=device, dtype=dtype) / 10.
    outproj_bias = torch.ones(dim, device=device, dtype=dtype) / 10.

    # Make D differ across heads (varying other per-head arguments shows a similar effect)
    D = torch.arange(nheads, device=device, dtype=dtype) * 1000.

    # Fused Triton path
    out_ref1 = mamba_split_conv1d_scan_combined(
        zxbcdt=zxbcdt,
        conv1d_weight=conv1d_weight,
        conv1d_bias=conv1d_bias,
        dt_bias=dt_bias,
        A=A,
        D=D,
        chunk_size=chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=rmsnorm_weight,
        rmsnorm_eps=rmsnorm_eps,
        outproj_weight=outproj_weight,
        outproj_bias=outproj_bias,
        headdim=headdim,
        ngroups=ngroups,
        norm_before_gate=False)

    # Reference path; this call omits initial_states, seq_idx and return_final_states
    out_ref2 = mamba_split_conv1d_scan_ref(
        zxbcdt=zxbcdt,
        conv1d_weight=conv1d_weight,
        conv1d_bias=conv1d_bias,
        dt_bias=dt_bias,
        A=A,
        D=D,
        chunk_size=chunk_size,
        # initial_states=None,
        # seq_idx=None,
        dt_limit=(0.0, float("inf")),
        # return_final_states=False,
        activation="silu",
        rmsnorm_weight=rmsnorm_weight,
        rmsnorm_eps=rmsnorm_eps,
        outproj_weight=outproj_weight,
        outproj_bias=outproj_bias,
        headdim=headdim,
        ngroups=ngroups,
        norm_before_gate=False)

    print("mamba_split_conv1d_scan_combined")
    print(out_ref1[0, 0, 0])
    print()
    print("mamba_split_conv1d_scan_ref")
    print(out_ref2[0, 0, 0])

for ngroups in [1, 2, 4]:
    print(f"\n\n------- ngroups = {ngroups} -------")
    f(ngroups=ngroups)
```
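To make the channel bookkeeping behind `ngroups` concrete, here is a sketch of how the last dimension of `zxbcdt` in the script above decomposes. It assumes the channel order suggested by the name (z, x, B, C, dt). It also skips the depthwise convolution that is applied to the x/B/C channels, and it assumes that each group serves a contiguous block of `nheads // ngroups` heads. `split_zxbcdt` and `head_to_group` are hypothetical helpers, not mamba_ssm functions.

```python
import torch


def split_zxbcdt(zxbcdt, nheads, headdim, ngroups, dstate):
    """Split the packed input into z, x, B, C, dt (assumed channel order)."""
    assert nheads % ngroups == 0, "each group must hold a whole number of heads"
    dim = nheads * headdim
    batch, seqlen, total = zxbcdt.shape
    assert total == 2 * dim + 2 * ngroups * dstate + nheads
    z, x, B, C, dt = torch.split(
        zxbcdt, [dim, dim, ngroups * dstate, ngroups * dstate, nheads], dim=-1
    )
    # z and x are per head: (batch, seqlen, nheads, headdim)
    z = z.reshape(batch, seqlen, nheads, headdim)
    x = x.reshape(batch, seqlen, nheads, headdim)
    # B and C hold one dstate-sized vector per group: (batch, seqlen, ngroups, dstate)
    B = B.reshape(batch, seqlen, ngroups, dstate)
    C = C.reshape(batch, seqlen, ngroups, dstate)
    return z, x, B, C, dt


def head_to_group(nheads, ngroups):
    """Group index read by each head, assuming contiguous blocks of heads per group."""
    return torch.arange(nheads) // (nheads // ngroups)


headdim, dstate, nheads = 64, 128, 32  # dim = 2048, as in the report
for ngroups in (1, 2, 4):
    width = 2 * nheads * headdim + 2 * ngroups * dstate + nheads
    zxbcdt = torch.ones(2, 8, width)  # a short seqlen is enough to check shapes
    z, x, B, C, dt = split_zxbcdt(zxbcdt, nheads, headdim, ngroups, dstate)
    print(ngroups, tuple(B.shape), head_to_group(nheads, ngroups)[:: nheads // 4].tolist())
```

With `ngroups = 1` every head reads the same B and C. Larger values give each block of heads its own B and C, which is also why the widths of `zxbcdt`, `conv1d_weight` and `conv1d_bias` in the script grow with `ngroups`.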

The code above prints the following:

```
------- ngroups = 1 -------
mamba_split_conv1d_scan_combined
tensor(17.6950, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')


------- ngroups = 2 -------
mamba_split_conv1d_scan_combined
tensor(18.8725, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')


------- ngroups = 4 -------
mamba_split_conv1d_scan_combined
tensor(19.5926, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')
```
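The printed values compare only `out[0, 0, 0]`. A small helper that summarizes the whole output tensor (maximum absolute and relative difference, the index of the largest difference, and an `allclose` verdict) may make it easier to see whether the mismatch is uniform or confined to particular channels. `report_mismatch` is a hypothetical helper, and its default tolerances are arbitrary choices.

```python
import torch


def report_mismatch(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-3) -> dict:
    """Summarize how far two same-shaped outputs are from each other."""
    assert a.shape == b.shape, f"shape mismatch: {a.shape} vs {b.shape}"
    a, b = a.float(), b.float()
    diff = (a - b).abs()
    # Relative difference w.r.t. b, guarding against division by zero
    rel = diff / b.abs().clamp_min(1e-6)
    # First index where the absolute difference is largest
    worst = (diff == diff.max()).nonzero()[0].tolist()
    return {
        "max_abs": diff.max().item(),
        "max_rel": rel.max().item(),
        "argmax": tuple(worst),
        "allclose": torch.allclose(a, b, rtol=rtol, atol=atol),
    }


# Synthetic example: two tensors that differ in a single element
a = torch.ones(2, 3, 4)
b = a.clone()
b[1, 2, 3] += 0.5
print(report_mismatch(a, b))
```

In the script above, one could call `print(report_mismatch(out_ref1, out_ref2))` at the end of `f` for each value of `ngroups`.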

Did I miss some arguments, or is ngroups > 1 simply not supported? Thanks in advance!
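One thing that may be worth checking, offered only as an unverified hypothesis: the reference value stays at 17.6951 for every `ngroups` while the combined value drifts. That is the pattern one would expect if the two paths applied the final RMSNorm over different channel groups. For example, one path might normalize each block of `dim // ngroups` channels separately while the other normalizes all `dim` channels together. Because `D` grows with the head index, per-group RMS statistics would then differ from the global one.

The plain-PyTorch sketch below does not use mamba_ssm. It shows that the two normalizations coincide for `ngroups = 1` and diverge otherwise. `rms_norm_grouped` is a hypothetical helper, and it omits the `z` gating used in the library.

```python
import torch


def rms_norm_grouped(x: torch.Tensor, weight: torch.Tensor, group_size: int, eps: float = 1e-5) -> torch.Tensor:
    """RMS-normalize the last dimension in independent blocks of `group_size` channels (no gating)."""
    *lead, d = x.shape
    assert d % group_size == 0
    xg = x.reshape(*lead, d // group_size, group_size)
    rstd = torch.rsqrt(xg.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xg * rstd).reshape(*lead, d) * weight


dim, headdim = 2048, 64
nheads = dim // headdim
# Per-head magnitude that grows with the head index, loosely mimicking D = arange(nheads) * 1000
x = torch.arange(nheads, dtype=torch.float32).repeat_interleave(headdim) + 1.0
weight = torch.ones(dim)

full = rms_norm_grouped(x, weight, group_size=dim)  # one group spanning all channels
for ngroups in (1, 2, 4):
    grouped = rms_norm_grouped(x, weight, group_size=dim // ngroups)
    print(ngroups, (grouped - full).abs().max().item())
```

If this hypothesis holds, the first printed difference is zero and the others are not, which mirrors the ngroups = 1 agreement seen in the report. Whether either mamba_ssm path actually normalizes this way would still need to be confirmed in the source.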
