From 7fb78a54b36a5875262688f1d222a75acd8e50d7 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 16 Jul 2024 03:30:39 +0200 Subject: [PATCH] Add mixin (#469) --- mamba_ssm/modules/mamba2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 85fd6dec..1859ab0d 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -31,8 +31,10 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined +from huggingface_hub import PyTorchModelHubMixin -class Mamba2(nn.Module): + +class Mamba2(nn.Module, PyTorchModelHubMixin): def __init__( self, d_model,