Could `Vocos.from_pretrained` support loading from a local directory? #65

Open
lilyzlt opened this issue Nov 3, 2024 · 0 comments

lilyzlt commented Nov 3, 2024

Firstly, thanks for your effort~
When I use `Vocos.from_pretrained("charactr/vocos-encodec-24khz")`, I found that it sends a request to Hugging Face every time, even though I have already downloaded the model. Could this code be updated to support loading from a local directory as well, instead of only downloading from the Hub?
Specifically, I suggest changing the existing code (shown in the attached image) to:
```python
# Module-level imports this method relies on (in vocos/pretrained.py):
import os
from typing import Optional

import torch
from huggingface_hub import hf_hub_download

from vocos.feature_extractors import EncodecFeatures


# Inside `class Vocos`:
@classmethod
def from_pretrained(cls, repo_or_dir: str, revision: Optional[str] = None) -> "Vocos":
    """
    Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub
    or from a local directory.

    Args:
        repo_or_dir (str): Either the repository ID on Hugging Face hub or the path to a local directory
            containing the model files (config.yaml and pytorch_model.bin).
        revision (Optional[str], optional): The specific model version to use (only used when loading
            from Hugging Face hub).

    Returns:
        Vocos: An instance of the Vocos model loaded with pretrained weights.
    """
    if os.path.isdir(repo_or_dir):
        # Load from local directory: no network request is made
        config_path = os.path.join(repo_or_dir, "config.yaml")
        model_path = os.path.join(repo_or_dir, "pytorch_model.bin")
    else:
        # Load from Hugging Face model hub
        config_path = hf_hub_download(repo_id=repo_or_dir, filename="config.yaml", revision=revision)
        model_path = hf_hub_download(repo_id=repo_or_dir, filename="pytorch_model.bin", revision=revision)

    # Build the model from its configuration, then load the saved weights on CPU
    model = cls.from_hparams(config_path)
    state_dict = torch.load(model_path, map_location="cpu")

    # The EnCodec feature extractor's weights are not in the checkpoint, so copy them in from the model itself
    if isinstance(model.feature_extractor, EncodecFeatures):
        encodec_parameters = {
            "feature_extractor.encodec." + key: value
            for key, value in model.feature_extractor.encodec.state_dict().items()
        }
        state_dict.update(encodec_parameters)

    # Load the state dictionary into the model and switch to inference mode
    model.load_state_dict(state_dict)
    model.eval()
    return model
```
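The key change is the path-resolution step: use the local files when `repo_or_dir` is an existing directory, and only otherwise contact the Hub. Below is a minimal, dependency-free sketch of just that step, assuming the directory holds the same `config.yaml` and `pytorch_model.bin` used above. `resolve_model_files` and `fake_download` are hypothetical names for illustration; `fake_download` stands in for `hf_hub_download` so the example runs offline. The sketch also adds a missing-file check, which the proposed code does not have.

```python
import os
import tempfile
from typing import Callable, Optional, Tuple

# Files the proposed from_pretrained expects (assumption: same names as above).
REQUIRED_FILES = ("config.yaml", "pytorch_model.bin")


def resolve_model_files(
    repo_or_dir: str,
    download: Callable[[str, str, Optional[str]], str],
    revision: Optional[str] = None,
) -> Tuple[str, str]:
    """Return (config_path, model_path), preferring a local directory.

    `download(repo_id, filename, revision)` is a hypothetical stand-in for
    huggingface_hub.hf_hub_download; it is only called when repo_or_dir
    is not an existing directory.
    """
    if os.path.isdir(repo_or_dir):
        paths = [os.path.join(repo_or_dir, name) for name in REQUIRED_FILES]
        missing = [p for p in paths if not os.path.isfile(p)]
        if missing:
            # Fail early with a clear message instead of an error deep in loading.
            raise FileNotFoundError(f"Missing model files: {missing}")
        return paths[0], paths[1]
    # Not a local directory: treat the string as a Hub repo ID.
    return (
        download(repo_or_dir, REQUIRED_FILES[0], revision),
        download(repo_or_dir, REQUIRED_FILES[1], revision),
    )


calls = []


def fake_download(repo_id: str, filename: str, revision: Optional[str]) -> str:
    # Records each "network" request so we can see when one happens.
    calls.append((repo_id, filename))
    return f"/cache/{repo_id}/{filename}"


with tempfile.TemporaryDirectory() as local_dir:
    for name in REQUIRED_FILES:
        open(os.path.join(local_dir, name), "w").close()  # empty placeholder files
    local_paths = resolve_model_files(local_dir, fake_download)
    calls_after_local = list(calls)  # stays empty: no request was made

hub_paths = resolve_model_files("charactr/vocos-encodec-24khz", fake_download)
```

One consequence of checking `os.path.isdir` first: if a directory that happens to be named like a repo ID (e.g. `charactr/vocos-encodec-24khz`) exists relative to the working directory, it is treated as local and the Hub is never contacted.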

thanks a lot~