mirror of https://github.com/coqui-ai/TTS.git
fix: only enable load with weights_only in pytorch>=2.4
Allows moving the minimum Pytorch version back to 2.1
parent b66c782931
commit 8e66be2c32
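The pattern applied throughout this diff: torch.load is only asked for weights_only=True when the installed PyTorch is 2.4 or newer. The weights_only path needs torch.serialization.add_safe_globals to allowlist the custom classes stored in these checkpoints, and that API only exists from PyTorch 2.4; older versions fall back to weights_only=False. A minimal sketch of the gated call, condensed from the diff below (the checkpoint path is hypothetical):

    import torch
    from packaging.version import Version

    def is_pytorch_at_least_2_4() -> bool:
        """Check if the installed Pytorch version is 2.4 or higher."""
        return Version(torch.__version__) >= Version("2.4")

    # Safe, allowlist-based unpickling on 2.4+; plain pickle loading before that.
    checkpoint = torch.load("model.pth", map_location="cpu", weights_only=is_pytorch_at_least_2_4())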
@@ -1,29 +1,33 @@
-import _codecs
 import importlib.metadata
-from collections import defaultdict
 
-import numpy as np
-import torch
-
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
-from TTS.utils.radam import RAdam
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 __version__ = importlib.metadata.version("coqui-tts")
 
 
-torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
-
-# Bark
-torch.serialization.add_safe_globals(
-    [
-        np.core.multiarray.scalar,
-        np.dtype,
-        np.dtypes.Float64DType,
-        _codecs.encode,  # TODO: safe by default from Pytorch 2.5
-    ]
-)
-# XTTS
-torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
+if is_pytorch_at_least_2_4():
+    import _codecs
+    from collections import defaultdict
+
+    import numpy as np
+    import torch
+
+    from TTS.config.shared_configs import BaseDatasetConfig
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+    from TTS.utils.radam import RAdam
+
+    torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
+
+    # Bark
+    torch.serialization.add_safe_globals(
+        [
+            np.core.multiarray.scalar,
+            np.dtype,
+            np.dtypes.Float64DType,
+            _codecs.encode,  # TODO: safe by default from Pytorch 2.5
+        ]
+    )
+    # XTTS
+    torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
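For context: with weights_only=True, torch.load uses a restricted unpickler that only reconstructs allowlisted types, and torch.serialization.add_safe_globals (new in PyTorch 2.4, hence the version guard above) extends that allowlist. A minimal sketch, assuming PyTorch >= 2.4 and a hypothetical checkpoint that pickles a custom class:

    import torch

    class MyModelConfig:  # stand-in for classes like XttsConfig above
        pass

    torch.serialization.add_safe_globals([MyModelConfig])
    # Without the allowlist entry, this load raises an UnpicklingError.
    checkpoint = torch.load("checkpoint_with_config.pth", weights_only=True)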
@@ -10,6 +10,7 @@ import tqdm
 
 from TTS.tts.layers.bark.model import GPT, GPTConfig
 from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 if (
     torch.cuda.is_available()
@@ -118,7 +119,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
         logger.info(f"{model_type} model not found, downloading...")
         _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)
 
-    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
+    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4())
     # this is a hack
     model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:
@@ -9,6 +9,7 @@ import torchaudio
 from transformers import LogitsWarper
 
 from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 
 def zero_module(module):
@@ -332,7 +333,7 @@ class TorchMelSpectrogram(nn.Module):
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             with fsspec.open(self.mel_norm_file) as f:
-                self.mel_norms = torch.load(f, weights_only=True)
+                self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4())
         else:
             self.mel_norms = None
 
@@ -10,6 +10,7 @@ import torchaudio
 from scipy.io.wavfile import read
 
 from TTS.utils.audio.torch_transforms import TorchSTFT
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -124,7 +125,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
     voices = get_voices(extra_voice_dirs)
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith(".pth"):
-        return None, torch.load(paths[0], weights_only=True)
+        return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4())
     else:
         conds = []
         for cond_path in paths:
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +48,7 @@ def dvae_wav_to_mel(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel
 
@@ -9,6 +9,7 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from trainer.io import load_fsspec
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
@@ -328,7 +329,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -19,6 +19,7 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -91,7 +92,9 @@ class GPTTrainer(BaseTTS):
 
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+            gpt_checkpoint = torch.load(
+                self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+            )
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +187,9 @@ class GPTTrainer(BaseTTS):
 
         self.dvae.eval()
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+            dvae_checkpoint = torch.load(
+                self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+            )
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
         else:
@@ -1,9 +1,11 @@
 import torch
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 
 class SpeakerManager:
     def __init__(self, speaker_file_path=None):
-        self.speakers = torch.load(speaker_file_path, weights_only=True)
+        self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4())
 
     @property
     def name_to_id(self):
@@ -18,7 +18,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -107,7 +107,7 @@ class NeuralhmmTTS(BaseTTS):
 
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -292,7 +292,9 @@ class NeuralhmmTTS(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+            statistics = torch.load(
+                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+            )
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],
@@ -19,7 +19,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -120,7 +120,7 @@ class Overflow(BaseTTS):
 
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -308,7 +308,9 @@ class Overflow(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+            statistics = torch.load(
+                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+            )
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],
@@ -23,6 +23,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -171,7 +172,11 @@ def classify_audio_clip(clip, model_dir):
         distribute_zero_label=False,
     )
     classifier.load_state_dict(
-        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+        torch.load(
+            os.path.join(model_dir, "classifier.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=is_pytorch_at_least_2_4(),
+        )
     )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
@@ -490,7 +495,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
@@ -498,7 +503,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
         with torch.no_grad():
@@ -885,17 +890,17 @@ class Tortoise(BaseTTS):
 
         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
 
             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)
 
         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
 
         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
 
         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
@@ -903,7 +908,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     vocoder_checkpoint_path,
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
         )
@@ -16,6 +16,7 @@ from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -65,7 +66,7 @@ def wav_to_mel_cloning(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel
 
@@ -1,8 +1,10 @@
 import torch
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 
 def rehash_fairseq_vits_checkpoint(checkpoint_file):
-    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
+    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"]
     new_chk = {}
     for k, v in chk.items():
         if "enc_p." in k:
@@ -9,6 +9,7 @@ import torch
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 
 def load_file(path: str):
@@ -17,7 +18,7 @@ def load_file(path: str):
             return json.load(f)
     elif path.endswith(".pth"):
         with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location="cpu", weights_only=True)
+            return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
     else:
         raise ValueError("Unsupported file type")
 
@@ -6,6 +6,9 @@ import re
 from pathlib import Path
 from typing import Dict, Optional
 
+import torch
+from packaging.version import Version
+
 logger = logging.getLogger(__name__)
 
 
@@ -131,3 +134,8 @@ def setup_logger(
         sh = logging.StreamHandler()
         sh.setFormatter(formatter)
         lg.addHandler(sh)
+
+
+def is_pytorch_at_least_2_4() -> bool:
+    """Check if the installed Pytorch version is 2.4 or higher."""
+    return Version(torch.__version__) >= Version("2.4")
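The helper relies on PEP 440 ordering from packaging.version, so release builds with CUDA suffixes compare as expected while pre-release nightlies sort below the final release. A few illustrative comparisons (the version strings are examples, not pins from this repo):

    from packaging.version import Version

    assert Version("2.4.0") >= Version("2.4")             # "2.4" and "2.4.0" compare equal
    assert Version("2.4.1+cu121") >= Version("2.4")       # +cu121 parses as a PEP 440 local version tag
    assert Version("2.1.2+cu118") < Version("2.4")        # older stable release
    assert Version("2.4.0.dev20240424") < Version("2.4")  # pre-releases sort below the final release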
@@ -5,6 +5,7 @@ import urllib.request
 import torch
 from trainer.io import get_user_data_dir
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
 
 logger = logging.getLogger(__name__)
@@ -26,7 +27,7 @@ def get_wavlm(device="cpu"):
         logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)
 
-    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
+    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4())
     cfg = WavLMConfig(checkpoint["cfg"])
     wavlm = WavLM(cfg).to(device)
     wavlm.load_state_dict(checkpoint["model"])
@@ -47,7 +47,7 @@ dependencies = [
     "numpy>=1.25.2,<2.0",
     "cython>=3.0.0",
     "scipy>=1.11.2",
-    "torch>=2.4",
+    "torch>=2.1",
     "torchaudio",
     "soundfile>=0.12.0",
     "librosa>=0.10.1",