diff --git a/TTS/tts/layers/bark/load_model.py b/TTS/tts/layers/bark/load_model.py
index ce6b757f..7785aab8 100644
--- a/TTS/tts/layers/bark/load_model.py
+++ b/TTS/tts/layers/bark/load_model.py
@@ -118,7 +118,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
         logger.info(f"{model_type} model not found, downloading...")
         _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)
 
-    checkpoint = torch.load(ckpt_path, map_location=device)
+    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
     # this is a hack
     model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:
diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py
index c79ef31b..f4dbcc80 100644
--- a/TTS/tts/layers/tortoise/arch_utils.py
+++ b/TTS/tts/layers/tortoise/arch_utils.py
@@ -332,7 +332,7 @@ class TorchMelSpectrogram(nn.Module):
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             with fsspec.open(self.mel_norm_file) as f:
-                self.mel_norms = torch.load(f)
+                self.mel_norms = torch.load(f, weights_only=True)
         else:
             self.mel_norms = None
 
diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py
index 0b870122..94c2bae6 100644
--- a/TTS/tts/layers/tortoise/audio_utils.py
+++ b/TTS/tts/layers/tortoise/audio_utils.py
@@ -124,7 +124,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
     voices = get_voices(extra_voice_dirs)
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith(".pth"):
-        return None, torch.load(paths[0])
+        return None, torch.load(paths[0], weights_only=True)
     else:
         conds = []
         for cond_path in paths:
diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py
index 4a37307e..58f91785 100644
--- a/TTS/tts/layers/xtts/dvae.py
+++ b/TTS/tts/layers/xtts/dvae.py
@@ -46,7 +46,7 @@ def dvae_wav_to_mel(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel
 
diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py
index b6032e55..09bd06df 100644
--- a/TTS/tts/layers/xtts/hifigan_decoder.py
+++ b/TTS/tts/layers/xtts/hifigan_decoder.py
@@ -328,7 +328,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 04d12377..f1aa6f8c 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -91,7 +91,7 @@ class GPTTrainer(BaseTTS):
 
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
+            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +184,7 @@ class GPTTrainer(BaseTTS):
         self.dvae.eval()
 
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
+            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
         else:
diff --git a/TTS/tts/layers/xtts/xtts_manager.py b/TTS/tts/layers/xtts/xtts_manager.py
index 5560e876..5a3c47ae 100644
--- a/TTS/tts/layers/xtts/xtts_manager.py
+++ b/TTS/tts/layers/xtts/xtts_manager.py
@@ -3,7 +3,7 @@ import torch
 
 class SpeakerManager:
     def __init__(self, speaker_file_path=None):
-        self.speakers = torch.load(speaker_file_path)
+        self.speakers = torch.load(speaker_file_path, weights_only=True)
 
     @property
     def name_to_id(self):
diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py
index 277369e6..49c48c2b 100644
--- a/TTS/tts/models/neuralhmm_tts.py
+++ b/TTS/tts/models/neuralhmm_tts.py
@@ -107,7 +107,7 @@ class NeuralhmmTTS(BaseTTS):
 
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -292,7 +292,7 @@ class NeuralhmmTTS(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
+            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],
diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py
index b05b7500..4c0f341b 100644
--- a/TTS/tts/models/overflow.py
+++ b/TTS/tts/models/overflow.py
@@ -120,7 +120,7 @@ class Overflow(BaseTTS):
 
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -308,7 +308,7 @@ class Overflow(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path)
+            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],
diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py
index 17303c69..98e79d0c 100644
--- a/TTS/tts/models/tortoise.py
+++ b/TTS/tts/models/tortoise.py
@@ -170,7 +170,9 @@ def classify_audio_clip(clip, model_dir):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(
+        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+    )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
@@ -488,6 +490,7 @@ class Tortoise(BaseTTS):
             torch.load(
                 os.path.join(self.models_dir, "rlg_auto.pth"),
                 map_location=torch.device("cpu"),
+                weights_only=True,
             )
         )
         self.rlg_diffusion = RandomLatentConverter(2048).eval()
@@ -495,6 +498,7 @@ class Tortoise(BaseTTS):
             torch.load(
                 os.path.join(self.models_dir, "rlg_diffuser.pth"),
                 map_location=torch.device("cpu"),
+                weights_only=True,
             )
         )
         with torch.no_grad():
@@ -881,17 +885,17 @@ class Tortoise(BaseTTS):
 
         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
 
             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)
 
         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
 
         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
 
         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
@@ -899,6 +903,7 @@ class Tortoise(BaseTTS):
                     torch.load(
                         vocoder_checkpoint_path,
                         map_location=torch.device("cpu"),
+                        weights_only=True,
                     )
                 )
             )
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index ef093442..0b7652e4 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -65,7 +65,7 @@ def wav_to_mel_cloning(
     mel = mel_stft(wav)
    mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel
 
diff --git a/TTS/tts/utils/fairseq.py b/TTS/tts/utils/fairseq.py
index 3d8eec2b..6eb1905d 100644
--- a/TTS/tts/utils/fairseq.py
+++ b/TTS/tts/utils/fairseq.py
@@ -2,7 +2,7 @@ import torch
 
 
 def rehash_fairseq_vits_checkpoint(checkpoint_file):
-    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
+    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
     new_chk = {}
     for k, v in chk.items():
         if "enc_p." in k:
diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py
index 23aa52a8..6f72581c 100644
--- a/TTS/tts/utils/managers.py
+++ b/TTS/tts/utils/managers.py
@@ -17,7 +17,7 @@ def load_file(path: str):
             return json.load(f)
     elif path.endswith(".pth"):
         with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location="cpu")
+            return torch.load(f, map_location="cpu", weights_only=True)
     else:
         raise ValueError("Unsupported file type")
 
diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py
index 03b2f582..528fade7 100644
--- a/TTS/vc/modules/freevc/wavlm/__init__.py
+++ b/TTS/vc/modules/freevc/wavlm/__init__.py
@@ -26,7 +26,7 @@ def get_wavlm(device="cpu"):
         logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)
 
-    checkpoint = torch.load(output_path, map_location=torch.device(device))
+    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
     cfg = WavLMConfig(checkpoint["cfg"])
     wavlm = WavLM(cfg).to(device)
     wavlm.load_state_dict(checkpoint["model"])
diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb
index d85ca103..f52fa028 100644
--- a/notebooks/TestAttention.ipynb
+++ b/notebooks/TestAttention.ipynb
@@ -119,9 +119,9 @@
     "\n",
     "# load model state\n",
     "if use_cuda:\n",
-    "    cp = torch.load(MODEL_PATH)\n",
+    "    cp = torch.load(MODEL_PATH, weights_only=True)\n",
     "else:\n",
-    "    cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
+    "    cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage, weights_only=True)\n",
     "\n",
     "# load the model\n",
     "model.load_state_dict(cp['model'])\n",
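
Every hunk above makes the same change: passing weights_only=True tells torch.load to use a restricted unpickler that only reconstructs tensors and primitive containers (dicts, lists, strings, numbers), so a tampered checkpoint file can no longer execute arbitrary code while being unpickled. A minimal sketch of the round trip, with a made-up checkpoint path used purely for illustration:

    import torch

    # An ordinary state-dict-style checkpoint: tensors and primitives only.
    state = {"model": {"weight": torch.zeros(2, 2)}, "step": 100}
    torch.save(state, "checkpoint.pth")  # hypothetical path

    # With weights_only=True the restricted unpickler is used, so loading
    # cannot run arbitrary code; unexpected Python objects in the file
    # raise an UnpicklingError instead of being silently executed.
    ckpt = torch.load("checkpoint.pth", map_location="cpu", weights_only=True)
    print(ckpt["step"])  # -> 100

If a checkpoint legitimately stores other Python objects, loading it this way fails with pickle.UnpicklingError; on PyTorch 2.4+ such classes can be allow-listed explicitly via torch.serialization.add_safe_globals rather than falling back to weights_only=False.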