Add HF integration, better discoverability #469

Merged: 1 commit into state-spaces:main (Jul 16, 2024)

Conversation

NielsRogge (Contributor) commented Jul 15, 2024

Hi @tridao and team,

I wrote a quick PoC to showcase how easy it is to integrate with the 🤗 hub, so that you can automatically load the various Mamba models using from_pretrained (and push them using push_to_hub), track download numbers for your models (similar to models in the Transformers library), and have nice model cards on a per-model basis. It leverages the PyTorchModelHubMixin class, which makes the model inherit these methods.

Yes, this works for any custom PyTorch model; it's not limited to Transformers/Diffusers :)
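
For reference, here is a minimal sketch of how the mixin attaches these methods to any custom nn.Module (the class name and hyperparameters below are purely illustrative and not part of mamba_ssm; this assumes a recent huggingface_hub version that serializes the __init__ kwargs to a config alongside the weights):

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MyModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_dim: int = 256, num_layers: int = 2):
        super().__init__()
        # the __init__ kwargs are what gets stored in the config
        self.layers = nn.Sequential(
            *[nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.layers(x)

model = MyModel(hidden_dim=128)
model.save_pretrained("my-model")               # writes weights + config locally
reloaded = MyModel.from_pretrained("my-model")  # rebuilds the model with the saved hyperparameters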

Usage is as follows:

import torch
from mamba_ssm import Mamba2

# illustrative input: (batch, sequence length, model dimension)
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

# optionally, one can push a trained model to the hub
model.push_to_hub("state-spaces/mamba2-demo")

# reload
model = Mamba2.from_pretrained("state-spaces/mamba2-demo")

This means people don't need to manually download a checkpoint to their local environment first; it just loads automatically from the hub. All checkpoints could be hosted under the state-spaces organization on the hub, or under a personal user account, if you're interested.

Would you be interested in this integration?

Kind regards,

Niels
ML @ HF

tridao merged commit 7fb78a5 into state-spaces:main on Jul 16, 2024

tridao (Collaborator) commented Jul 16, 2024

This looks very convenient, thanks!

NielsRogge (Contributor, Author) commented:

Thanks for quickly merging my PR! Would you be interested in trying it out?

tridao (Collaborator) commented Jul 16, 2024

Actually, I just realized that maybe this mixin should be part of MambaLMHeadModel (which is the model) instead of Mamba2 (which is a layer within a model)?

class MambaLMHeadModel(nn.Module, GenerationMixin):
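
A rough sketch of what that could look like (just adding the mixin to the class bases; the actual change lives in the follow-up PR referenced below):

from huggingface_hub import PyTorchModelHubMixin

class MambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin):
    ...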

NielsRogge (Contributor, Author) commented:

Thanks, that's right; it will be addressed in #471.
