From 8e83a212fa91de6816909a3fb174e7978b1f2655 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Sat, 14 Aug 2021 17:52:00 -0300
Subject: [PATCH] Add multilingual inference support

---
 TTS/tts/configs/vits_config.py | 12 +++---
 TTS/tts/models/base_tts.py     | 45 +++++++++++++++++++++++++++
 TTS/tts/models/vits.py         | 57 +++++++++++++++++++---------------
 TTS/tts/utils/speakers.py      | 41 +++++++++++++++++++++---
 TTS/tts/utils/synthesis.py     | 22 ++++++++-----
 5 files changed, 133 insertions(+), 44 deletions(-)

diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index d490e6e6..3e031f02 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -130,13 +130,13 @@ class VitsConfig(BaseTTSConfig):
     add_blank: bool = True
 
     # testing
-    test_sentences: List[str] = field(
+    test_sentences: List[List] = field(
         default_factory=lambda: [
-            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-            "Be a voice, not an echo.",
-            "I'm sorry Dave. I'm afraid I can't do that.",
-            "This cake is great. It's so delicious and moist.",
-            "Prior to November 22, 1963.",
+            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
+            ["Be a voice, not an echo."],
+            ["I'm sorry Dave. I'm afraid I can't do that."],
+            ["This cake is great. It's so delicious and moist."],
+            ["Prior to November 22, 1963."],
         ]
     )

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index c0d2bd78..bfa6df14 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -107,6 +107,51 @@ class BaseTTS(BaseModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
+    def get_aux_input(self, **kwargs) -> Dict:
+        """Prepare and return `aux_input` used by `forward()`."""
+        return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
+
+    def get_aux_input_from_test_sentences(self, sentence_info):
+        if hasattr(self.config, "model_args"):
+            config = self.config.model_args
+        else:
+            config = self.config
+
+        # extract speaker and language info
+        text, speaker_name, style_wav, language_name = None, None, None, None
+
+        if isinstance(sentence_info, list):
+            if len(sentence_info) == 1:
+                text = sentence_info[0]
+            elif len(sentence_info) == 2:
+                text, speaker_name = sentence_info
+            elif len(sentence_info) == 3:
+                text, speaker_name, style_wav = sentence_info
+            elif len(sentence_info) == 4:
+                text, speaker_name, style_wav, language_name = sentence_info
+        else:
+            text = sentence_info
+
+        # get speaker id/d_vector
+        speaker_id, d_vector, language_id = None, None, None
+        if hasattr(self, "speaker_manager") and config.use_speaker_embedding:
+            if config.use_d_vector_file:
+                if speaker_name is None:
+                    d_vector = self.speaker_manager.get_random_d_vector()
+                else:
+                    d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
+            else:
+                if speaker_name is None:
+                    speaker_id = self.speaker_manager.get_random_speaker_id()
+                else:
+                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+
+        # get language id
+        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+            language_id = self.language_manager.language_id_mapping[language_name]
+
+        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
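Note: each `test_sentences` entry is now a list of up to four positional fields,
`[text, speaker_name, style_wav, language_name]`, which
`get_aux_input_from_test_sentences()` unpacks by length. A minimal standalone
sketch of that unpacking rule; `parse_sentence_info` is a hypothetical helper
name, and the second sentence plus the "speaker_0"/"fr-fr" values are
illustrative, not from the repo:

    def parse_sentence_info(sentence_info):
        """Mirror of the length-based unpacking in the patch, runnable in isolation."""
        text, speaker_name, style_wav, language_name = None, None, None, None
        if isinstance(sentence_info, list):
            # Pad the entry to four fields so one unpack covers all arities;
            # the patch itself uses an explicit elif chain for the same effect.
            padded = sentence_info + [None] * (4 - len(sentence_info))
            text, speaker_name, style_wav, language_name = padded[:4]
        else:
            text = sentence_info  # a bare string is treated as text-only
        return text, speaker_name, style_wav, language_name

    # Entries may carry one to four fields:
    print(parse_sentence_info(["Be a voice, not an echo."]))
    print(parse_sentence_info(["Bonjour le monde.", "speaker_0", None, "fr-fr"]))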
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 3a682ce5..11f1fab0 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -399,8 +399,14 @@ class Vits(BaseTTS):
             sid = sid.unsqueeze_(0)
         if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
             g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
+            if g.ndim == 2:
+                g = g.unsqueeze_(0)
+
         if "language_ids" in aux_input and aux_input["language_ids"] is not None:
             lid = aux_input["language_ids"]
+            if lid.ndim == 0:
+                lid = lid.unsqueeze_(0)
+
         return sid, g, lid
 
     def get_aux_input(self, aux_input: Dict):
@@ -437,9 +443,8 @@ class Vits(BaseTTS):
         """
         outputs = {}
         sid, g, lid = self._set_cond_input(aux_input)
-
         # speaker embedding
-        if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
+        if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
 
         # language embedding
@@ -521,11 +526,11 @@ class Vits(BaseTTS):
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
 
         # speaker embedding
-        if self.num_speakers > 0 and sid:
+        if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
             g = self.emb_g(sid).unsqueeze(-1)
 
         # language embedding
-        if self.args.use_language_embedding:
+        if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
 
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
@@ -713,29 +718,29 @@ class Vits(BaseTTS):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        if hasattr(self, "speaker_manager"):
-            aux_inputs = self.speaker_manager.get_random_speaker_aux_input()
-        else:
-            aux_inputs = self.get_aux_input()
-
-        for idx, sen in enumerate(test_sentences):
+        for idx, s_info in enumerate(test_sentences):
+            try:
+                aux_inputs = self.get_aux_input_from_test_sentences(s_info)
+                wav, alignment, _, _ = synthesis(
+                    self,
+                    aux_inputs["text"],
+                    self.config,
+                    "cuda" in str(next(self.parameters()).device),
+                    ap,
+                    speaker_id=aux_inputs["speaker_id"],
+                    d_vector=aux_inputs["d_vector"],
+                    style_wav=aux_inputs["style_wav"],
+                    language_id=aux_inputs["language_id"],
+                    enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                    use_griffin_lim=True,
+                    do_trim_silence=False,
+                ).values()
-
-            wav, alignment, _, _ = synthesis(
-                self,
-                sen,
-                self.config,
-                "cuda" in str(next(self.parameters()).device),
-                ap,
-                speaker_id=aux_inputs["speaker_id"],
-                d_vector=aux_inputs["d_vector"],
-                style_wav=aux_inputs["style_wav"],
-                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
-                use_griffin_lim=True,
-                do_trim_silence=False,
-            ).values()
-
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+                test_audios["{}-audio".format(idx)] = wav
+                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+            except:  # pylint: disable=bare-except
+                print(" !! Error creating Test Sentence -", idx)
         return test_figures, test_audios
 
     def get_optimizer(self) -> List:
@@ -832,3 +837,5 @@ class Vits(BaseTTS):
             if eval:
                 self.eval()
                 assert not self.training
+
+
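Note: the new guards in `_set_cond_input` normalize unbatched conditioning
inputs, so inference can be called with a single d-vector or a scalar language
id. A standalone shape sketch of what the guards do; the embedding size 256 is
illustrative:

    import torch
    import torch.nn.functional as F

    # Batched d-vectors: [B, D] -> L2-normalize over D, then [B, D, 1].
    d_vectors = torch.randn(2, 256)
    g = F.normalize(d_vectors).unsqueeze(-1)
    assert g.shape == (2, 256, 1)

    # An unbatched input ends up 2-D after unsqueeze(-1); the new guard
    # restores the batch dimension so downstream code always sees [B, D, 1].
    g_single = torch.randn(256, 1)  # stands in for a single [D, 1] d-vector
    if g_single.ndim == 2:
        g_single = g_single.unsqueeze(0)
    assert g_single.shape == (1, 256, 1)

    # A scalar language id likewise becomes a 1-element batch.
    lid = torch.tensor(3)
    if lid.ndim == 0:
        lid = lid.unsqueeze(0)
    assert lid.shape == (1,)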
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index b7dd5251..1497ca74 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -193,6 +193,20 @@ class SpeakerManager:
         """
         return [x["embedding"] for x in self.d_vectors.values() if x["name"] == speaker_idx]
 
+    def get_d_vector_by_speaker(self, speaker_idx: str) -> np.ndarray:
+        """Get a d_vector of a speaker.
+
+        Args:
+            speaker_idx (str): Target speaker ID.
+
+        Returns:
+            np.ndarray: d_vector.
+        """
+        for x in self.d_vectors.values():
+            if x["name"] == speaker_idx:
+                return x["embedding"]
+        return None
+
     def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
         """Get mean d_vector of a speaker ID.
@@ -215,14 +229,31 @@ class SpeakerManager:
             d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
         return d_vectors
 
-    def get_random_speaker_aux_input(self) -> Dict:
-        if self.d_vectors:
-            return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]}
+    def get_random_speaker_id(self) -> Any:
+        """Get a random speaker ID.
+
+        Returns:
+            Any: speaker ID.
+        """
         if self.speaker_ids:
-            return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None}
+            return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
-
-        return {"speaker_id": None, "style_wav": None, "d_vector": None}
+        return None
+
+    def get_random_d_vector(self) -> Any:
+        """Get a random d_vector.
+
+        Returns:
+            np.ndarray: d_vector.
+        """
+        if self.d_vectors:
+            return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
+
+        return None
 
     def get_speakers(self) -> List:
         return self.speaker_ids
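Note: the new `SpeakerManager` helpers give the test loop a named lookup with a
random fallback. A standalone sketch of that pattern against a toy d-vector
store shaped like `self.d_vectors`; the functions mirror the patched methods
but live outside the class, and the keys, speaker names, and 256-dim embeddings
are illustrative:

    import random
    import numpy as np

    # Toy stand-in for SpeakerManager.d_vectors: one entry per utterance,
    # each carrying the speaker name and that utterance's embedding.
    d_vectors = {
        "clip_001": {"name": "speaker_a", "embedding": np.zeros(256)},
        "clip_002": {"name": "speaker_b", "embedding": np.ones(256)},
    }

    def get_d_vector_by_speaker(speaker_idx):
        # First utterance of the matching speaker wins, as in the patched method.
        for x in d_vectors.values():
            if x["name"] == speaker_idx:
                return x["embedding"]
        return None

    def get_random_d_vector():
        # random.choice(...) is equivalent to the patch's random.choices(...)[0]
        # for a single draw.
        if d_vectors:
            return d_vectors[random.choice(list(d_vectors.keys()))]["embedding"]
        return None

    # Named lookup when the test sentence supplies a speaker, random fallback otherwise:
    vec = get_d_vector_by_speaker("speaker_b")
    if vec is None:
        vec = get_random_d_vector()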
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 578c26c0..63fe92c3 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -71,6 +71,7 @@ def run_model_torch(
     speaker_id: int = None,
     style_mel: torch.Tensor = None,
     d_vector: torch.Tensor = None,
+    language_id: torch.Tensor = None,
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.
@@ -96,6 +97,7 @@ def run_model_torch(
             "speaker_ids": speaker_id,
             "d_vectors": d_vector,
             "style_mel": style_mel,
+            "language_ids": language_id,
         },
     )
     return outputs
@@ -160,13 +162,13 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav
 
 
-def speaker_id_to_torch(speaker_id, cuda=False):
-    if speaker_id is not None:
-        speaker_id = np.asarray(speaker_id)
-        speaker_id = torch.from_numpy(speaker_id)
+def id_to_torch(aux_id, cuda=False):
+    if aux_id is not None:
+        aux_id = np.asarray(aux_id)
+        aux_id = torch.from_numpy(aux_id)
     if cuda:
-        return speaker_id.cuda()
-    return speaker_id
+        return aux_id.cuda()
+    return aux_id
 
 
 def embedding_to_torch(d_vector, cuda=False):
@@ -208,6 +210,7 @@ def synthesis(
     use_griffin_lim=False,
     do_trim_silence=False,
     d_vector=None,
+    language_id=None,
     backend="torch",
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
@@ -262,11 +265,14 @@ def synthesis(
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
-            speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
+            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
 
         if d_vector is not None:
             d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
 
+        if language_id is not None:
+            language_id = id_to_torch(language_id, cuda=use_cuda)
+
         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
@@ -278,7 +284,7 @@ def synthesis(
         text_inputs = tf.expand_dims(text_inputs, 0)
     # synthesize voice
     if backend == "torch":
-        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
+        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
         model_outputs = outputs["model_outputs"]
         model_outputs = model_outputs[0].data.cpu().numpy()
         alignments = outputs["alignments"]
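Note: with `language_id` threaded through `synthesis()` and
`run_model_torch()`, a multilingual inference call looks like the test-run call
in vits.py above. A hedged end-to-end sketch, assuming a trained multilingual
`model` (with the speaker and language managers attached), its `config`, and an
`AudioProcessor` instance `ap` are already loaded elsewhere; the "fr-fr"
language key and the French sentence are illustrative:

    from TTS.tts.utils.synthesis import synthesis

    # `model`, `config`, and `ap` are assumed to exist already.
    use_cuda = "cuda" in str(next(model.parameters()).device)
    speaker_id = model.speaker_manager.get_random_speaker_id()
    language_id = model.language_manager.language_id_mapping["fr-fr"]  # illustrative key

    wav, alignment, _, _ = synthesis(
        model,
        "Bonjour le monde.",
        config,
        use_cuda,
        ap,
        speaker_id=speaker_id,
        style_wav=None,
        language_id=language_id,  # new argument added by this patch
        enable_eos_bos_chars=config.enable_eos_bos_chars,
        use_griffin_lim=True,
        do_trim_silence=False,
    ).values()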