mirror of https://github.com/coqui-ai/TTS.git
Add new speakers to the vits model
This commit is contained in:
parent d5c0e17548
commit f70e4bb8c6
@@ -53,6 +53,10 @@ class CharactersConfig(Coqpit):
     """Defines arguments for the `BaseCharacters` and its subclasses.
 
     Args:
+        characters_class (str):
+            Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
+            the configuration. Defaults to None.
+
         pad (str):
             characters in place of empty padding. Defaults to None.
 
@@ -84,6 +88,7 @@ class CharactersConfig(Coqpit):
         Sort the characters in alphabetical order. Defaults to True.
     """
 
+    characters_class: str = None
     pad: str = None
     eos: str = None
     bos: str = None
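The new `characters_class` field lets a config name the character-set class directly instead of relying on the Phonemes/Graphemes heuristic. A minimal usage sketch; the import path and the dotted class path are assumptions based on the surrounding code, not part of this commit:

# Hypothetical sketch: selecting a character class explicitly via the config.
# The import path and the "Graphemes" dotted path are assumed, not from this diff.
from TTS.tts.configs.shared_configs import CharactersConfig

config = CharactersConfig(
    characters_class="TTS.tts.utils.text.characters.Graphemes",
    pad="<PAD>",
    eos="<EOS>",
    bos="<BOS>",
)
print(config.characters_class)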
@@ -649,12 +649,7 @@ class Vits(BaseTTS):
         z_p = self.flow(z, y_mask, g=g)
 
         # duration predictor
-        if self.args.use_mas:
-            outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
-        elif self.args.use_aligner_network:
-            outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
-            outputs["x_lens"] = x_lengths
-            outputs["y_lens"] = y_lengths
+        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
 
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
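For readers new to the `torch.einsum("klmn, kjm -> kjn", ...)` line kept at the end of this hunk: it projects the per-token prior means onto the frame axis through the alignment map. A standalone sketch with toy shapes (all names and sizes here are illustrative, not taken from the model):

import torch

# Toy shapes: batch k=2, channels j=192, text length m=5, frames n=40.
attn = torch.zeros(2, 1, 5, 40)   # [k, 1, m, n] hard alignment map
attn[:, 0, 0, :] = 1.0            # e.g. every frame attends to token 0
m_p = torch.randn(2, 192, 5)      # [k, j, m] per-token prior means

# "klmn, kjm -> kjn": for each frame n, sum the priors of the token(s) m
# it is aligned to, yielding a [k, j, n] frame-level prior.
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
print(m_p_expanded.shape)  # torch.Size([2, 192, 40])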
@@ -1059,7 +1054,15 @@ class Vits(BaseTTS):
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
-        self.load_state_dict(state["model"])
+        # handle fine-tuning from a checkpoint with additional speakers
+        if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
+            print(" > Loading checkpoint with additional speakers.")
+            emb_g = state["model"]["emb_g.weight"]
+            new_row = torch.zeros(1, emb_g.shape[1])
+            emb_g = torch.cat([emb_g, new_row], axis=0)
+            state["model"]["emb_g.weight"] = emb_g
+
+        self.load_state_dict(state["model"], strict=False)
         if eval:
             self.eval()
             assert not self.training
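The resizing block above appends a single zero row to `emb_g.weight` so a model configured with one extra speaker can still load an older checkpoint, then calls `load_state_dict` with `strict=False` to tolerate the remaining key mismatches. A standalone sketch of the same idea, generalized to any number of new speakers (tensor names and sizes are illustrative):

import torch

# Checkpoint trained with 10 speakers, embedding dim 256.
ckpt_emb = torch.randn(10, 256)

# Current model expects 12 speakers: pad with zero rows for the new IDs.
num_new_speakers = 12 - ckpt_emb.shape[0]
new_rows = torch.zeros(num_new_speakers, ckpt_emb.shape[1])
padded = torch.cat([ckpt_emb, new_rows], dim=0)

print(padded.shape)  # torch.Size([12, 256])
# The padded matrix can replace "emb_g.weight" in the state dict before
# load_state_dict(..., strict=False), as the diff above does for one speaker.

Note that the diff itself appends exactly one row, so it only covers fine-tuning with a single added speaker; `strict=False` does not forgive shape mismatches, so adding several speakers would need the generalized padding sketched here.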