Skip to content

Commit

Permalink
fix: define torch safe globals for torch.load
Browse files Browse the repository at this point in the history
Required for loading some models using torch.load(..., weights_only=True). This
is only available from Pytorch 2.4
  • Loading branch information
eginhard committed Sep 12, 2024
1 parent 17ca24c commit 86b58fb
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
26 changes: 26 additions & 0 deletions TTS/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
import _codecs
import importlib.metadata
from collections import defaultdict

import numpy as np
import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.utils.radam import RAdam

__version__ = importlib.metadata.version("coqui-tts")


torch.serialization.add_safe_globals([dict, defaultdict, RAdam])

# Bark
torch.serialization.add_safe_globals(
[
np.core.multiarray.scalar,
np.dtype,
np.dtypes.Float64DType,
_codecs.encode, # TODO: safe by default from Pytorch 2.5
]
)

# XTTS
torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
3 changes: 0 additions & 3 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.vits import Vits

# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ classifiers = [
]
dependencies = [
# Core
"numpy>=1.24.3,<2.0.0", # TODO: remove upper bound after spacy/thinc release
"numpy>=1.25.2,<2.0.0", # TODO: remove upper bound after spacy/thinc release
"cython>=0.29.30",
"scipy>=1.11.2",
"torch>=2.1",
"torch>=2.4",
"torchaudio",
"soundfile>=0.12.0",
"librosa>=0.10.1",
Expand Down

0 comments on commit 86b58fb

Please sign in to comment.