Add new speakers to the vits model

This commit is contained in:
Eren Gölge 2022-01-28 10:22:12 +01:00
parent d5c0e17548
commit f70e4bb8c6
2 changed files with 15 additions and 7 deletions

View File

@ -53,6 +53,10 @@ class CharactersConfig(Coqpit):
"""Defines arguments for the `BaseCharacters` and its subclasses.
Args:
characters_class (str):
Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
the configuration. Defaults to None.
pad (str):
characters in place of empty padding. Defaults to None.
@ -84,6 +88,7 @@ class CharactersConfig(Coqpit):
Sort the characters in alphabetical order. Defaults to True.
"""
characters_class: str = None
pad: str = None
eos: str = None
bos: str = None

View File

@ -649,12 +649,7 @@ class Vits(BaseTTS):
z_p = self.flow(z, y_mask, g=g)
# duration predictor
if self.args.use_mas:
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
elif self.args.use_aligner_network:
outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
outputs["x_lens"] = x_lengths
outputs["y_lens"] = y_lengths
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
# expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@ -1059,7 +1054,15 @@ class Vits(BaseTTS):
# TODO: consider baking the speaker encoder into the model and call it from there.
# as it is probably easier for model distribution.
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
self.load_state_dict(state["model"])
# handle fine-tuning from a checkpoint with additional speakers
if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
print(" > Loading checkpoint with additional speakers.")
emb_g = state["model"]["emb_g.weight"]
new_row = torch.zeros(1, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0)
state["model"]["emb_g.weight"] = emb_g
self.load_state_dict(state["model"], strict=False)
if eval:
self.eval()
assert not self.training