Merge pull request #77 from shavit/71-torch-load

Load weights only in torch.load
Enno Hermann, 2024-09-12 23:28:57 +01:00 (committed by GitHub)
commit e5dd06b3bb
24 changed files with 58 additions and 35 deletions
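
By default torch.load unpickles arbitrary Python objects, so a malicious checkpoint can execute code at load time. Passing weights_only=True restricts unpickling to tensors, primitive containers, and an explicit allowlist of types, which is what this PR switches every checkpoint load to. A minimal sketch of the difference (the checkpoint path here is hypothetical, not a file from this repository):

    import torch

    ckpt_path = "model.pth"  # hypothetical path to any checkpoint written with torch.save()

    # Before this PR: full pickle loading; arbitrary objects embedded in the file are instantiated.
    state = torch.load(ckpt_path, map_location="cpu")

    # After this PR: only tensors and allowlisted types are unpickled;
    # anything else raises pickle.UnpicklingError instead of being executed.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=True)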


@@ -55,6 +55,7 @@ jobs:
       - name: Upload coverage data
         uses: actions/upload-artifact@v4
         with:
+          include-hidden-files: true
           name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
           path: .coverage.*
           if-no-files-found: ignore
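
Note: the coverage data files matched by .coverage.* are hidden files, and recent releases of actions/upload-artifact skip hidden files unless include-hidden-files: true is set; that is presumably why this flag is added here, so the coverage artifacts keep being uploaded.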


@@ -48,7 +48,6 @@
             "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt",
             "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
             "https://coqui.gateway.scarf.sh/hf/bark/config.json",
-            "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt",
             "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth"
         ],
         "default_vocoder": null,


@@ -1,3 +1,29 @@
+import _codecs
 import importlib.metadata
+from collections import defaultdict
+
+import numpy as np
+import torch
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+from TTS.utils.radam import RAdam

 __version__ = importlib.metadata.version("coqui-tts")
+
+torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
+
+# Bark
+torch.serialization.add_safe_globals(
+    [
+        np.core.multiarray.scalar,
+        np.dtype,
+        np.dtypes.Float64DType,
+        _codecs.encode,  # TODO: safe by default from Pytorch 2.5
+    ]
+)
+
+# XTTS
+torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
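
weights_only=True only admits a small built-in set of types, so checkpoints that pickle project classes (the XTTS configs, the RAdam optimizer, Bark's numpy scalar metadata) have to be allowlisted up front via torch.serialization.add_safe_globals, which is what the block above does at import time. A short sketch of the failure mode and the fix, assuming a checkpoint that contains a pickled XttsConfig (the path is an example, not a file shipped with the repo):

    import torch
    from TTS.tts.configs.xtts_config import XttsConfig

    # Without the allowlist, a weights-only load of such a checkpoint fails with
    # an UnpicklingError that names the blocked global (here XttsConfig).
    torch.serialization.add_safe_globals([XttsConfig])

    # Assumed example path for illustration only.
    checkpoint = torch.load("xtts_model.pth", map_location="cpu", weights_only=True)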


@@ -96,7 +96,6 @@ class BarkConfig(BaseTTSConfig):
             "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"),
             "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"),
             "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"),
-            "hubert": os.path.join(self.CACHE_DIR, "hubert.pt"),
         }
         self.SMALL_REMOTE_MODEL_PATHS = {
             "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")},


@@ -40,7 +40,7 @@ class CustomHubert(nn.Module):
     or you can train your own
     """

-    def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
+    def __init__(self, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
         super().__init__()
         self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of


@@ -134,10 +134,9 @@ def generate_voice(
     # generate semantic tokens
     # Load the HuBERT model
     hubert_manager = HubertManager()
-    # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
     hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])

-    hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
+    hubert_model = CustomHubert().to(model.device)

     # Load the CustomTokenizer model
     tokenizer = HubertTokenizer.load_from_checkpoint(


@@ -118,7 +118,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
         logger.info(f"{model_type} model not found, downloading...")
         _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

-    checkpoint = torch.load(ckpt_path, map_location=device)
+    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
     # this is a hack
     model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:


@@ -332,7 +332,7 @@ class TorchMelSpectrogram(nn.Module):
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             with fsspec.open(self.mel_norm_file) as f:
-                self.mel_norms = torch.load(f)
+                self.mel_norms = torch.load(f, weights_only=True)
         else:
             self.mel_norms = None


@@ -124,7 +124,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
     voices = get_voices(extra_voice_dirs)
    paths = voices[voice]
    if len(paths) == 1 and paths[0].endswith(".pth"):
-        return None, torch.load(paths[0])
+        return None, torch.load(paths[0], weights_only=True)
    else:
        conds = []
        for cond_path in paths:


@@ -46,7 +46,7 @@ def dvae_wav_to_mel(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel


@@ -328,7 +328,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()


@@ -91,7 +91,7 @@ class GPTTrainer(BaseTTS):
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
+            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +184,7 @@ class GPTTrainer(BaseTTS):
         self.dvae.eval()
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
+            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
         else:


@@ -3,7 +3,7 @@ import torch

 class SpeakerManager:
     def __init__(self, speaker_file_path=None):
-        self.speakers = torch.load(speaker_file_path)
+        self.speakers = torch.load(speaker_file_path, weights_only=True)

     @property
     def name_to_id(self):


@@ -243,7 +243,6 @@ class Bark(BaseTTS):
         text_model_path=None,
         coarse_model_path=None,
         fine_model_path=None,
-        hubert_model_path=None,
         hubert_tokenizer_path=None,
         eval=False,
         strict=True,
@@ -266,13 +265,11 @@ class Bark(BaseTTS):
         text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt")
         coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt")
         fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt")
-        hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt")
         hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth")

         self.config.LOCAL_MODEL_PATHS["text"] = text_model_path
         self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path
         self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path
-        self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path
         self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path

         self.load_bark_models()


@@ -107,7 +107,7 @@ class NeuralhmmTTS(BaseTTS):
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
             self.update_mean_std(statistics_dict)

         mels = self.normalize(mels)
@@ -292,7 +292,7 @@ class NeuralhmmTTS(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
+            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],


@@ -120,7 +120,7 @@ class Overflow(BaseTTS):
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
             self.update_mean_std(statistics_dict)

         mels = self.normalize(mels)
@@ -308,7 +308,7 @@ class Overflow(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
+            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],


@@ -170,7 +170,9 @@ def classify_audio_clip(clip, model_dir):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(
+        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+    )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
@@ -488,6 +490,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
@@ -495,6 +498,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
         with torch.no_grad():
@@ -881,17 +885,17 @@ class Tortoise(BaseTTS):
         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)

         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)

         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)

         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
@@ -899,6 +903,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     vocoder_checkpoint_path,
                     map_location=torch.device("cpu"),
+                    weights_only=True,
                 )
             )
         )


@@ -65,7 +65,7 @@ def wav_to_mel_cloning(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel


@@ -2,7 +2,7 @@ import torch

 def rehash_fairseq_vits_checkpoint(checkpoint_file):
-    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
+    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
     new_chk = {}
     for k, v in chk.items():
         if "enc_p." in k:


@@ -17,7 +17,7 @@ def load_file(path: str):
             return json.load(f)
     elif path.endswith(".pth"):
         with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location="cpu")
+            return torch.load(f, map_location="cpu", weights_only=True)
     else:
         raise ValueError("Unsupported file type")


@@ -12,9 +12,6 @@ from TTS.config import load_config
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.models import setup_model as setup_tts_model
 from TTS.tts.models.vits import Vits
-
-# pylint: disable=unused-wildcard-import
-# pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import save_wav


@@ -26,7 +26,7 @@ def get_wavlm(device="cpu"):
         logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)

-    checkpoint = torch.load(output_path, map_location=torch.device(device))
+    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
     cfg = WavLMConfig(checkpoint["cfg"])
     wavlm = WavLM(cfg).to(device)
     wavlm.load_state_dict(checkpoint["model"])


@@ -119,9 +119,9 @@
     "\n",
     "# load model state\n",
     "if use_cuda:\n",
-    "    cp = torch.load(MODEL_PATH)\n",
+    "    cp = torch.load(MODEL_PATH, weights_only=True)\n",
     "else:\n",
-    "    cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
+    "    cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage, weights_only=True)\n",
     "\n",
     "# load the model\n",
     "model.load_state_dict(cp['model'])\n",


@@ -44,10 +44,10 @@ classifiers = [
 ]
 dependencies = [
     # Core
-    "numpy>=1.24.3,<2.0.0",  # TODO: remove upper bound after spacy/thinc release
+    "numpy>=1.25.2,<2.0.0",  # TODO: remove upper bound after spacy/thinc release
     "cython>=0.29.30",
     "scipy>=1.11.2",
-    "torch>=2.1",
+    "torch>=2.4",
     "torchaudio",
     "soundfile>=0.12.0",
     "librosa>=0.10.1",