mirror of https://github.com/coqui-ai/TTS.git

Add new speakers to the vits model

parent d5c0e17548
commit f70e4bb8c6
@@ -53,6 +53,10 @@ class CharactersConfig(Coqpit):
     """Defines arguments for the `BaseCharacters` and its subclasses.

     Args:
+        characters_class (str):
+            Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on
+            the configuration. Defaults to None.
+
         pad (str):
             characters in place of empty padding. Defaults to None.

@@ -84,6 +88,7 @@ class CharactersConfig(Coqpit):
             Sort the characters in alphabetical order. Defaults to True.
     """

+    characters_class: str = None
     pad: str = None
     eos: str = None
     bos: str = None
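For context, the new `characters_class` field lets a config name the character-set class explicitly instead of relying on the phoneme/grapheme heuristic. A minimal sketch of setting it, assuming the `TTS.tts.configs.shared_configs` import path and a dotted class path as the value (both are assumptions, not shown in this commit):

```python
# Hedged sketch: build a CharactersConfig using the new characters_class field.
# The dotted path below is a hypothetical value; leaving it as None keeps the
# existing behaviour of picking Phonemes or Graphemes from the configuration.
from TTS.tts.configs.shared_configs import CharactersConfig

characters = CharactersConfig(
    characters_class="TTS.tts.utils.text.characters.Graphemes",  # assumption, not from this commit
    pad="<PAD>",
    eos="<EOS>",
    bos="<BOS>",
)
print(characters.characters_class)
```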
@@ -649,12 +649,7 @@ class Vits(BaseTTS):
         z_p = self.flow(z, y_mask, g=g)

         # duration predictor
-        if self.args.use_mas:
-            outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
-        elif self.args.use_aligner_network:
-            outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
-            outputs["x_lens"] = x_lengths
-            outputs["y_lens"] = y_lengths
+        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)

         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
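The retained `torch.einsum("klmn, kjm -> kjn", ...)` line expands the per-token prior means to frame level using the MAS alignment. A toy shape check, with the batch/channel/length sizes made up purely for illustration:

```python
# Toy shape check for the prior expansion above.
# attn: [batch, 1, text_len, frame_len] hard alignment; m_p: [batch, channels, text_len].
import torch

batch, channels, text_len, frame_len = 2, 192, 13, 87  # illustrative sizes
attn = torch.rand(batch, 1, text_len, frame_len)
m_p = torch.rand(batch, channels, text_len)

m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
assert m_p_expanded.shape == (batch, channels, frame_len)
```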
@@ -1059,7 +1054,15 @@ class Vits(BaseTTS):
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
-        self.load_state_dict(state["model"])
+        # handle fine-tuning from a checkpoint with additional speakers
+        if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
+            print(" > Loading checkpoint with additional speakers.")
+            emb_g = state["model"]["emb_g.weight"]
+            new_row = torch.zeros(1, emb_g.shape[1])
+            emb_g = torch.cat([emb_g, new_row], axis=0)
+            state["model"]["emb_g.weight"] = emb_g
+
+        self.load_state_dict(state["model"], strict=False)
         if eval:
             self.eval()
             assert not self.training
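The new branch pads the speaker embedding table with a zero row when the current model expects one more speaker than the checkpoint provides, so fine-tuning can add a speaker; `strict=False` then tolerates the remaining key mismatches. A standalone sketch of the same padding idea (the helper name, sizes, and the generalisation to more than one new speaker are illustrative additions, not part of the commit):

```python
# Sketch of the emb_g.weight padding applied in load_checkpoint, generalised to
# any number of missing rows. Names and sizes here are made up for illustration.
import torch

def pad_speaker_embedding(state_dict, target_rows):
    """Append zero rows to 'emb_g.weight' until it has target_rows rows."""
    emb_g = state_dict["emb_g.weight"]
    if emb_g.shape[0] < target_rows:
        pad = torch.zeros(target_rows - emb_g.shape[0], emb_g.shape[1])
        emb_g = torch.cat([emb_g, pad], dim=0)
    state_dict["emb_g.weight"] = emb_g
    return state_dict

# Hypothetical usage: checkpoint trained with 10 speakers, model now configured for 11.
state = {"emb_g.weight": torch.randn(10, 256)}
state = pad_speaker_embedding(state, target_rows=11)
assert state["emb_g.weight"].shape == (11, 256)
```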