mirror of https://github.com/coqui-ai/TTS.git
Update Vits speaker encoder init
This commit is contained in:
parent
27b67b7945
commit
c68885b3fd
|
@ -654,7 +654,7 @@ class Vits(BaseTTS):
|
|||
# TODO: move the speaker-encoder-loss check below into its own function
|
||||
if self.args.use_speaker_encoder_as_loss:
|
||||
if self.speaker_manager.speaker_encoder is None and (
|
||||
not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
|
||||
not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
|
||||
):
|
||||
raise RuntimeError(
|
||||
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
|
||||
|
@ -1445,13 +1445,13 @@ class Vits(BaseTTS):
|
|||
# as it is probably easier for model distribution.
|
||||
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
|
||||
# handle fine-tuning from a checkpoint with additional speakers
|
||||
if hasattr(self, "emb_g") and state["model"]["vits.emb_g.weight"].shape != self.emb_g.weight.shape:
|
||||
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["vits.emb_g.weight"].shape[0]
|
||||
if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
|
||||
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
|
||||
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
|
||||
emb_g = state["model"]["vits.emb_g.weight"]
|
||||
emb_g = state["model"]["emb_g.weight"]
|
||||
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
|
||||
emb_g = torch.cat([emb_g, new_row], axis=0)
|
||||
state["model"]["vits.emb_g.weight"] = emb_g
|
||||
state["model"]["emb_g.weight"] = emb_g
|
||||
# load the model weights
|
||||
self.load_state_dict(state["model"], strict=strict)
|
||||
|
||||
|
@ -1479,14 +1479,12 @@ class Vits(BaseTTS):
|
|||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
language_manager = LanguageManager.init_from_config(config)
|
||||
|
||||
if config.model_args.speaker_encoder_model_path is not None:
|
||||
speaker_manager.init_speaker_encoder(config.model_args.speaker_encoder_model_path,
|
||||
config.model_args.speaker_encoder_config_path)
|
||||
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
|
||||
|
||||
|
||||
##################################
|
||||
# VITS CHARACTERS
|
||||
##################################
|
||||
|
||||
|
||||
##################################
|
||||
# VITS CHARACTERS
|
||||
##################################
|
||||
|
|
Loading…
Reference in New Issue