Update Vits speaker encoder init

This commit is contained in:
Eren Gölge 2022-03-02 13:20:23 +01:00
parent 27b67b7945
commit c68885b3fd
1 changed files with 9 additions and 11 deletions

View File

@ -654,7 +654,7 @@ class Vits(BaseTTS):
# TODO: make this a function
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.speaker_encoder is None and (
not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
@ -1445,13 +1445,13 @@ class Vits(BaseTTS):
# as it is probably easier for model distribution.
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
# handle fine-tuning from a checkpoint with additional speakers
if hasattr(self, "emb_g") and state["model"]["vits.emb_g.weight"].shape != self.emb_g.weight.shape:
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["vits.emb_g.weight"].shape[0]
if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
emb_g = state["model"]["vits.emb_g.weight"]
emb_g = state["model"]["emb_g.weight"]
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0)
state["model"]["vits.emb_g.weight"] = emb_g
state["model"]["emb_g.weight"] = emb_g
# load the model weights
self.load_state_dict(state["model"], strict=strict)
@ -1479,14 +1479,12 @@ class Vits(BaseTTS):
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
language_manager = LanguageManager.init_from_config(config)
if config.model_args.speaker_encoder_model_path is not None:
speaker_manager.init_speaker_encoder(config.model_args.speaker_encoder_model_path,
config.model_args.speaker_encoder_config_path)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
##################################
# VITS CHARACTERS
##################################
##################################
# VITS CHARACTERS
##################################