diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 036f22f2..8c15103f 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -654,7 +654,7 @@ class Vits(BaseTTS):
         # TODO: make this a function
         if self.args.use_speaker_encoder_as_loss:
             if self.speaker_manager.speaker_encoder is None and (
-                not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
+                not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
             ):
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
@@ -1445,13 +1445,13 @@ class Vits(BaseTTS):
         # as it is probably easier for model distribution.
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
         # handle fine-tuning from a checkpoint with additional speakers
-        if hasattr(self, "emb_g") and state["model"]["vits.emb_g.weight"].shape != self.emb_g.weight.shape:
-            num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["vits.emb_g.weight"].shape[0]
+        if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
+            num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
             print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
-            emb_g = state["model"]["vits.emb_g.weight"]
+            emb_g = state["model"]["emb_g.weight"]
             new_row = torch.randn(num_new_speakers, emb_g.shape[1])
             emb_g = torch.cat([emb_g, new_row], axis=0)
-            state["model"]["vits.emb_g.weight"] = emb_g
+            state["model"]["emb_g.weight"] = emb_g
         # load the model weights
         self.load_state_dict(state["model"], strict=strict)

@@ -1479,14 +1479,12 @@ class Vits(BaseTTS):
         tokenizer, new_config = TTSTokenizer.init_from_config(config)
         speaker_manager = SpeakerManager.init_from_config(config, samples)
         language_manager = LanguageManager.init_from_config(config)
+
+        if config.model_args.speaker_encoder_model_path is not None:
+            speaker_manager.init_speaker_encoder(config.model_args.speaker_encoder_model_path,
+                                                 config.model_args.speaker_encoder_config_path)
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)

-
-##################################
-# VITS CHARACTERS
-##################################
-
-
 ##################################
 # VITS CHARACTERS
 ##################################