From 86b58fb6d992a2250313481580be4a3eab4fa28f Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 12 Sep 2024 17:04:10 +0200 Subject: [PATCH] fix: define torch safe globals for torch.load Required for loading some models using torch.load(..., weights_only=True). This is only available from Pytorch 2.4 --- TTS/__init__.py | 26 ++++++++++++++++++++++++++ TTS/utils/synthesizer.py | 3 --- pyproject.toml | 4 ++-- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/TTS/__init__.py b/TTS/__init__.py index 9e87bca4be..64c7369bc0 100644 --- a/TTS/__init__.py +++ b/TTS/__init__.py @@ -1,3 +1,29 @@ +import _codecs import importlib.metadata +from collections import defaultdict + +import numpy as np +import torch + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig +from TTS.utils.radam import RAdam __version__ = importlib.metadata.version("coqui-tts") + + +torch.serialization.add_safe_globals([dict, defaultdict, RAdam]) + +# Bark +torch.serialization.add_safe_globals( + [ + np.core.multiarray.scalar, + np.dtype, + np.dtypes.Float64DType, + _codecs.encode, # TODO: safe by default from Pytorch 2.5 + ] +) + +# XTTS +torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs]) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 50a7893047..90af4f48f9 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -12,9 +12,6 @@ from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.models.vits import Vits - -# pylint: disable=unused-wildcard-import -# pylint: disable=wildcard-import from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import save_wav diff --git a/pyproject.toml b/pyproject.toml index 94ed3a2c36..371d0b10dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,10 @@ classifiers = [ ] dependencies = [ # Core - "numpy>=1.24.3,<2.0.0", # TODO: remove upper bound after spacy/thinc release + "numpy>=1.25.2,<2.0.0", # TODO: remove upper bound after spacy/thinc release "cython>=0.29.30", "scipy>=1.11.2", - "torch>=2.1", + "torch>=2.4", "torchaudio", "soundfile>=0.12.0", "librosa>=0.10.1",