fix: define torch safe globals for torch.load

Required for loading some models using torch.load(..., weights_only=True). This
is only available from Pytorch 2.4
This commit is contained in:
Enno Hermann 2024-09-12 17:04:10 +02:00
parent 17ca24c3d6
commit 86b58fb6d9
3 changed files with 28 additions and 5 deletions

View File

@ -1,3 +1,29 @@
import _codecs
import importlib.metadata
from collections import defaultdict
import numpy as np
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.utils.radam import RAdam
__version__ = importlib.metadata.version("coqui-tts")
torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
# Bark
torch.serialization.add_safe_globals(
[
np.core.multiarray.scalar,
np.dtype,
np.dtypes.Float64DType,
_codecs.encode, # TODO: safe by default from Pytorch 2.5
]
)
# XTTS
torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])

View File

@ -12,9 +12,6 @@ from TTS.config import load_config
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.vits import Vits
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav

View File

@ -44,10 +44,10 @@ classifiers = [
]
dependencies = [
# Core
"numpy>=1.24.3,<2.0.0", # TODO: remove upper bound after spacy/thinc release
"numpy>=1.25.2,<2.0.0", # TODO: remove upper bound after spacy/thinc release
"cython>=0.29.30",
"scipy>=1.11.2",
"torch>=2.1",
"torch>=2.4",
"torchaudio",
"soundfile>=0.12.0",
"librosa>=0.10.1",