diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index 5c271f07..96cf0427 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -53,6 +53,10 @@ class CharactersConfig(Coqpit):
     """Defines arguments for the `BaseCharacters` and its subclasses.
 
     Args:
+        characters_class (str):
+            Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
+            the configuration. Defaults to None.
+
         pad (str):
             characters in place of empty padding. Defaults to None.
 
@@ -84,6 +88,7 @@ class CharactersConfig(Coqpit):
             Sort the characters in alphabetical order. Defaults to True.
     """
 
+    characters_class: str = None
     pad: str = None
     eos: str = None
     bos: str = None
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index cb4499fb..a69b02ba 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -649,12 +649,7 @@ class Vits(BaseTTS):
         z_p = self.flow(z, y_mask, g=g)
 
         # duration predictor
-        if self.args.use_mas:
-            outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
-        elif self.args.use_aligner_network:
-            outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
-            outputs["x_lens"] = x_lengths
-            outputs["y_lens"] = y_lengths
+        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
 
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@@ -1059,7 +1054,15 @@ class Vits(BaseTTS):
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
-        self.load_state_dict(state["model"])
+        # handle fine-tuning from a checkpoint with additional speakers
+        if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
+            print(" > Loading checkpoint with additional speakers.")
+            emb_g = state["model"]["emb_g.weight"]
+            new_row = torch.zeros(1, emb_g.shape[1])
+            emb_g = torch.cat([emb_g, new_row], axis=0)
+            state["model"]["emb_g.weight"] = emb_g
+
+        self.load_state_dict(state["model"], strict=False)
         if eval:
             self.eval()
             assert not self.training
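A minimal standalone sketch of the speaker-embedding resize that the `load_checkpoint` hunk above performs: when the running model has one more speaker than the checkpoint, the checkpoint's `emb_g.weight` is padded with a zero row so the otherwise shape-mismatched state dict can be loaded with `strict=False`. The tensor shapes and the one-extra-speaker count here are illustrative assumptions; the real code operates on the full VITS state dict.

```python
import torch

# Checkpoint speaker embedding: 4 speakers x 256-dim (illustrative shapes).
ckpt_emb_g = torch.randn(4, 256)

# Running model expects 5 speakers (one new speaker added for fine-tuning).
model_emb_g = torch.zeros(5, 256)

if ckpt_emb_g.shape != model_emb_g.shape:
    # Pad the checkpoint embedding with a zero row for the new speaker,
    # mirroring the torch.cat([emb_g, new_row], axis=0) call in the diff.
    new_row = torch.zeros(1, ckpt_emb_g.shape[1])
    ckpt_emb_g = torch.cat([ckpt_emb_g, new_row], dim=0)

# Shapes now match, so load_state_dict(..., strict=False) can proceed.
assert ckpt_emb_g.shape == model_emb_g.shape
```

Note that the diff pads exactly one row, so it covers the common case of fine-tuning with a single additional speaker; the new speaker's embedding starts at zero and is learned during fine-tuning.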