mirror of https://github.com/coqui-ai/TTS.git
fix: define torch safe globals for torch.load
Required for loading some models using torch.load(..., weights_only=True). This is only available from Pytorch 2.4
This commit is contained in:
parent
17ca24c3d6
commit
86b58fb6d9
|
@ -1,3 +1,29 @@
|
||||||
|
import _codecs
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
|
from TTS.tts.configs.xtts_config import XttsConfig
|
||||||
|
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
|
||||||
|
from TTS.utils.radam import RAdam
|
||||||
|
|
||||||
__version__ = importlib.metadata.version("coqui-tts")
|
__version__ = importlib.metadata.version("coqui-tts")
|
||||||
|
|
||||||
|
|
||||||
|
torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
|
||||||
|
|
||||||
|
# Bark
|
||||||
|
torch.serialization.add_safe_globals(
|
||||||
|
[
|
||||||
|
np.core.multiarray.scalar,
|
||||||
|
np.dtype,
|
||||||
|
np.dtypes.Float64DType,
|
||||||
|
_codecs.encode, # TODO: safe by default from Pytorch 2.5
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# XTTS
|
||||||
|
torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
|
||||||
|
|
|
@ -12,9 +12,6 @@ from TTS.config import load_config
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.models import setup_model as setup_tts_model
|
from TTS.tts.models import setup_model as setup_tts_model
|
||||||
from TTS.tts.models.vits import Vits
|
from TTS.tts.models.vits import Vits
|
||||||
|
|
||||||
# pylint: disable=unused-wildcard-import
|
|
||||||
# pylint: disable=wildcard-import
|
|
||||||
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.audio.numpy_transforms import save_wav
|
from TTS.utils.audio.numpy_transforms import save_wav
|
||||||
|
|
|
@ -44,10 +44,10 @@ classifiers = [
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core
|
# Core
|
||||||
"numpy>=1.24.3,<2.0.0", # TODO: remove upper bound after spacy/thinc release
|
"numpy>=1.25.2,<2.0.0", # TODO: remove upper bound after spacy/thinc release
|
||||||
"cython>=0.29.30",
|
"cython>=0.29.30",
|
||||||
"scipy>=1.11.2",
|
"scipy>=1.11.2",
|
||||||
"torch>=2.1",
|
"torch>=2.4",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"soundfile>=0.12.0",
|
"soundfile>=0.12.0",
|
||||||
"librosa>=0.10.1",
|
"librosa>=0.10.1",
|
||||||
|
|
Loading…
Reference in New Issue