Add multilingual inference support

Edresson 2021-08-14 17:52:00 -03:00 committed by Eren Gölge
parent d0e3647db6
commit 8e83a212fa
5 changed files with 133 additions and 44 deletions

View File

@@ -130,13 +130,13 @@ class VitsConfig(BaseTTSConfig):
add_blank: bool = True
# testing
test_sentences: List[str] = field(
test_sentences: List[List] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
["Be a voice, not an echo."],
["I'm sorry Dave. I'm afraid I can't do that."],
["This cake is great. It's so delicious and moist."],
["Prior to November 22, 1963."],
]
)
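Each test sentence is now a list, so optional speaker, style-wav, and language fields can ride along with the text. A minimal sketch of the extended form (the speaker and language names below are hypothetical, not part of this commit):

test_sentences = [
    ["Be a voice, not an echo."],  # text only
    ["Be a voice, not an echo.", "speaker_a", None, "en"],  # text, speaker_name, style_wav, language_name
]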

View File

@@ -107,6 +107,51 @@ class BaseTTS(BaseModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
def get_aux_input_from_test_sentences(self, sentence_info):
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
text = sentence_info[0]
elif len(sentence_info) == 2:
text, speaker_name = sentence_info
elif len(sentence_info) == 3:
text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info
else:
text = sentence_info
# get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None
if hasattr(self, "speaker_manager") and config.use_speaker_embedding:
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
else:
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
else:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.

View File

@@ -399,8 +399,14 @@ class Vits(BaseTTS):
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
if g.ndim == 2:
g = g.unsqueeze_(0)
if "language_ids" in aux_input and aux_input["language_ids"] is not None:
lid = aux_input["language_ids"]
if lid.ndim == 0:
lid = lid.unsqueeze_(0)
return sid, g, lid
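The new `lid` branch mirrors the speaker-ID handling: a scalar language ID arrives as a 0-dim tensor (see `id_to_torch` below) and needs a batch dimension before the embedding lookup. A standalone sketch:

import torch

lid = torch.tensor(2)        # 0-dim tensor, as id_to_torch produces
if lid.ndim == 0:
    lid = lid.unsqueeze_(0)  # shape [1], a valid nn.Embedding input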
def get_aux_input(self, aux_input: Dict):
@@ -437,9 +443,8 @@ class Vits(BaseTTS):
"""
outputs = {}
sid, g, lid = self._set_cond_input(aux_input)
# speaker embedding
if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# language embedding
@@ -521,11 +526,11 @@ class Vits(BaseTTS):
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# speaker embedding
if self.num_speakers > 0 and sid:
if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
g = self.emb_g(sid).unsqueeze(-1)
# language embedding
if self.args.use_language_embedding:
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
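For context, `emb_l` is an embedding table over language IDs whose output is fed alongside the text into the encoder. A self-contained sketch; the table sizes and dimension names here are assumptions for illustration:

import torch
import torch.nn as nn

num_languages, embedded_language_dim = 4, 4   # assumed sizes
emb_l = nn.Embedding(num_languages, embedded_language_dim)
lid = torch.tensor([1])                       # batched language ID
lang_emb = emb_l(lid).unsqueeze(-1)           # [B, embedded_language_dim, 1]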
@@ -713,29 +718,29 @@ class Vits(BaseTTS):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
if hasattr(self, "speaker_manager"):
aux_inputs = self.speaker_manager.get_random_speaker_aux_input()
else:
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
for idx, s_info in enumerate(test_sentences):
try:
aux_inputs = self.get_aux_input_from_test_sentences(s_info)
wav, alignment, _, _ = synthesis(
self,
aux_inputs["text"],
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
language_id=aux_inputs["language_id"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
wav, alignment, _, _ = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
return test_figures, test_audios
def get_optimizer(self) -> List:
@@ -832,3 +837,5 @@ class Vits(BaseTTS):
if eval:
self.eval()
assert not self.training

View File

@@ -193,6 +193,20 @@ class SpeakerManager:
"""
return [x["embedding"] for x in self.d_vectors.values() if x["name"] == speaker_idx]
def get_d_vector_by_speaker(self, speaker_idx: str) -> np.ndarray:
"""Get a d_vector of a speaker.
Args:
speaker_idx (str): Target speaker ID.
Returns:
np.ndarray: d_vector.
"""
for x in self.d_vectors.values():
if x["name"] == speaker_idx:
return x["embedding"]
return None
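Hypothetical usage of the new lookup (the speaker name is invented); it returns None for unknown speakers, so pairing it with the random fallback defined further down is a natural pattern:

d_vector = speaker_manager.get_d_vector_by_speaker("speaker_a")
if d_vector is None:
    d_vector = speaker_manager.get_random_d_vector()  # defined below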
def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
"""Get mean d_vector of a speaker ID.
@@ -215,14 +229,31 @@ class SpeakerManager:
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
return d_vectors
def get_random_speaker_aux_input(self) -> Dict:
if self.d_vectors:
return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]}
def get_random_speaker_id(self) -> Any:
"""Get a random d_vector.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.speaker_ids:
return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None}
return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
return {"speaker_id": None, "style_wav": None, "d_vector": None}
return None
def get_random_d_vector(self) -> Any:
"""Get a random D ID.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.d_vectors:
return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
return None
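These two helpers back the `speaker_name is None` branches in `get_aux_input_from_test_sentences` above; a quick sketch of the fallback behavior:

sid = speaker_manager.get_random_speaker_id()  # speaker ID, or None if no speakers are loaded
dvec = speaker_manager.get_random_d_vector()   # np.ndarray, or None if no d_vectors are loaded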
def get_speakers(self) -> List:
return self.speaker_ids

View File

@@ -71,6 +71,7 @@ def run_model_torch(
speaker_id: int = None,
style_mel: torch.Tensor = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@@ -96,6 +97,7 @@ def run_model_torch(
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
"language_ids": language_id,
},
)
return outputs
@@ -160,13 +162,13 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav
def speaker_id_to_torch(speaker_id, cuda=False):
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id)
def id_to_torch(aux_id, cuda=False):
if aux_id is not None:
aux_id = np.asarray(aux_id)
aux_id = torch.from_numpy(aux_id)
if cuda:
return speaker_id.cuda()
return speaker_id
return aux_id.cuda()
return aux_id
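`id_to_torch` generalizes the old `speaker_id_to_torch` so the same helper serves both speaker and language IDs. Note that it returns a 0-dim tensor, which is exactly what the unsqueeze logic in `_set_cond_input` compensates for:

speaker_id = id_to_torch(2)   # tensor(2), 0-dim
language_id = id_to_torch(1)  # tensor(1), 0-dim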
def embedding_to_torch(d_vector, cuda=False):
@@ -208,6 +210,7 @@ def synthesis(
use_griffin_lim=False,
do_trim_silence=False,
d_vector=None,
language_id=None,
backend="torch",
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
@@ -262,11 +265,14 @@ def synthesis(
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
@@ -278,7 +284,7 @@ def synthesis(
text_inputs = tf.expand_dims(text_inputs, 0)
# synthesize voice
if backend == "torch":
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
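Putting it together, a hypothetical end-to-end call with the new argument; the IDs, `config`, and `ap` objects are placeholders, and the keyword names are assumed from the surrounding signature:

wav, alignment, _, _ = synthesis(
    model,
    "Hola mundo.",
    config,
    use_cuda=False,
    ap=ap,
    speaker_id=0,
    language_id=1,
    use_griffin_lim=True,
    do_trim_silence=False,
).values()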