Skip to content

mamba_split_conv1d_scan_combined and mamba_split_conv1d_scan_ref exhibit large numerical differences when ngroups > 1 #647

Open
@HanGuo97

Description

@HanGuo97

First of all, thanks for this amazing library!

This is likely not the intended use case of the library, but I noticed (by chance) that the behavior of mamba_split_conv1d_scan_combined seems to differ from that of its reference(?) Python implementation, mamba_split_conv1d_scan_ref, when using ngroups > 1.

import torch
from mamba_ssm.ops.triton.ssd_combined import (
    mamba_split_conv1d_scan_ref,
    mamba_split_conv1d_scan_combined)

def f(ngroups: int) -> None:
    """Call the combined op and the ref op on identical inputs and print one element of each.

    All inputs are constant tensors (ones divided by 10; ``A`` is
    ``-exp(0)``), except ``D``, which takes a different value per head.
    For each of the two functions a label line is printed, followed by
    element ``[0, 0, 0]`` of its result.

    Args:
        ngroups: Forwarded as ``ngroups`` to both calls; it also sizes the
            last dimension of ``zxbcdt`` and the first dimension of the
            conv1d weight and bias.
    """
    batch, length = 2, 2048
    model_dim = 1024 * 2
    head_dim = 64
    state_dim = 128
    num_heads = model_dim // head_dim
    conv_width = 4  # size of the second dimension of conv1d_weight
    tensor_opts = {"device": "cuda", "dtype": torch.float32}

    def tenths(*shape):
        # Constant tensor of the given shape: ones divided by 10.
        return torch.ones(*shape, **tensor_opts) / 10.

    # Size shared by the conv1d weight's first dimension and the conv1d bias.
    conv_channels = model_dim + 2 * ngroups * state_dim

    # Keyword arguments passed unchanged to BOTH implementations.
    shared_kwargs = dict(
        zxbcdt=tenths(batch, length, model_dim + conv_channels + num_heads),
        conv1d_weight=tenths(conv_channels, conv_width),
        conv1d_bias=tenths(conv_channels),
        dt_bias=tenths(num_heads),
        A=-torch.exp(torch.zeros(num_heads, **tensor_opts)),
        # D is the one argument that varies across heads: 0, 1000, 2000, ...
        D=torch.arange(num_heads, **tensor_opts) * 1000.,
        chunk_size=128,
        dt_limit=(0.0, float("inf")),
        activation="silu",
        rmsnorm_weight=tenths(model_dim),
        rmsnorm_eps=1e-5,
        outproj_weight=tenths(model_dim, model_dim),
        outproj_bias=tenths(model_dim),
        headdim=head_dim,
        ngroups=ngroups,
        norm_before_gate=False,
    )

    # These three keywords are passed only to the combined call; the ref call
    # below does not receive them.
    combined_out = mamba_split_conv1d_scan_combined(
        initial_states=None,
        seq_idx=None,
        return_final_states=False,
        **shared_kwargs,
    )
    ref_out = mamba_split_conv1d_scan_ref(**shared_kwargs)

    # Both results are computed before anything is printed.
    for label, result in (
        ("mamba_split_conv1d_scan_combined", combined_out),
        ("mamba_split_conv1d_scan_ref", ref_out),
    ):
        if result is ref_out:
            print()  # blank separator line between the two reports
        print(label)
        print(result[0, 0, 0])

# Run the side-by-side comparison once per group count, printing a banner
# before each run so the outputs can be told apart.
for ngroups in (1, 2, 4):
    print(f"\n\n------- ngroups = {ngroups} -------")
    f(ngroups=ngroups)

The above code prints the following:

------- ngroups = 1 -------
mamba_split_conv1d_scan_combined
tensor(17.6950, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')


------- ngroups = 2 -------
mamba_split_conv1d_scan_combined
tensor(18.8725, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')


------- ngroups = 4 -------
mamba_split_conv1d_scan_combined
tensor(19.5926, device='cuda:0')

mamba_split_conv1d_scan_ref
tensor(17.6951, device='cuda:0')

Did I miss certain arguments, or is ngroups > 1 simply not a supported feature? Thanks in advance!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions