From eeb8ac07d964db769fd6eea01c8669739a60e3de Mon Sep 17 00:00:00 2001
From: Edresson
Date: Tue, 14 Sep 2021 17:27:00 -0300
Subject: [PATCH] Add voice conversion fine tuning mode

---
 TTS/bin/find_unique_phonemes.py | 63 +++++++++++++++++++++++++++++++++
 TTS/tts/layers/losses.py        |  2 +-
 TTS/tts/models/vits.py          | 59 +++++++++++++++++++++++++++---
 3 files changed, 119 insertions(+), 5 deletions(-)
 create mode 100644 TTS/bin/find_unique_phonemes.py

diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py
new file mode 100644
index 00000000..7ed79b36
--- /dev/null
+++ b/TTS/bin/find_unique_phonemes.py
@@ -0,0 +1,63 @@
+"""Find all the unique phonemes in a dataset."""
+import argparse
+from argparse import RawTextHelpFormatter
+
+from TTS.config import load_config
+from TTS.tts.datasets import load_meta_data
+
+import numpy
+import multiprocessing
+from TTS.tts.utils.text import text2phone
+from tqdm.contrib.concurrent import process_map
+
+def compute_phonemes(item):
+    try:
+        text = item[0]
+        language = item[-1]
+        ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
+    except Exception:
+        return []
+    return list(set(ph))
+
+def main():
+    global c
+    # pylint: disable=bad-option-value
+    parser = argparse.ArgumentParser(
+        description="""Find all the unique phonemes in a dataset.\n\n"""
+        """
+    Example runs:
+
+    python TTS/bin/find_unique_phonemes.py --config_path config.json
+    """,
+        formatter_class=RawTextHelpFormatter,
+    )
+    parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
+    args = parser.parse_args()
+
+    c = load_config(args.config_path)
+
+    # load all datasets
+    train_items, eval_items = load_meta_data(c.datasets, eval_split=True)
+    items = train_items + eval_items
+    print("Num items:", len(items))
+    # items = items[:1000]
+
+    phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
+    phones = []
+    for ph in phonemes:
+        phones.extend(ph)
+    phones = set(phones)
+    lower_phones = filter(lambda p: p.islower(), phones)
+    phones_force_lower = [p.lower() for p in phones]
+    phones_force_lower = set(phones_force_lower)
+
+
+
+    print(f" > Number of unique phonemes: {len(phones)}")
+    print(f" > Unique phonemes: {''.join(sorted(phones))}")
+    print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
+    print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index fdee9c10..cd2903b0 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -599,7 +599,7 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_fake,
         feats_disc_real,
         loss_duration,
-        fine_tuning_mode=False,
+        fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
         syn_spk_emb=None
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 71cc4634..a9078b26 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -149,6 +149,28 @@ class VitsArgs(Coqpit):

         detach_dp_input (bool):
             Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
+
+        use_language_embedding (bool):
+            Enable/Disable language embedding for multilingual models. Defaults to False.
+
+        embedded_language_dim (int):
+            Number of language embedding channels. Defaults to 4.
+
+        num_languages (int):
+            Number of languages for the language embedding layer. Defaults to 0.
+
+        use_speaker_encoder_as_loss (bool):
+            Enable/Disable using the speaker encoder to compute a speaker consistency loss. Defaults to False.
+        speaker_encoder_config_path (str):
+            Path to the speaker encoder config file. Defaults to "".
+        speaker_encoder_model_path (str):
+            Path to the speaker encoder model file. Defaults to "".
+
+        fine_tuning_mode (int):
+            Fine-tune only the vocoder part of the model while the rest is kept frozen. Defaults to 0.
+            Mode 0: disabled;
+            Mode 1: uses the prior distribution predicted by the text encoder (`m_p`), recommended for TTS;
+            Mode 2: uses the latents predicted by the posterior encoder mapped through the flow (`z_p`), recommended for voice conversion.
     """

     num_chars: int = 100
@@ -194,10 +216,10 @@ class VitsArgs(Coqpit):
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
     num_languages: int = 0
-    fine_tuning_mode: bool = False
     use_speaker_encoder_as_loss: bool = False
     speaker_encoder_config_path: str = ""
     speaker_encoder_model_path: str = ""
+    fine_tuning_mode: int = 0



@@ -565,6 +587,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
+        waveform=None,
     ) -> Dict:
         """Forward pass of the model.

@@ -621,22 +644,50 @@ class Vits(BaseTTS):
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

-        # get the z after inverse decoder
-        # ToDo: test if using m_p the result is better (In the SC-GlowTTS paper we used mp instead z_p)
-        z_f_pred = self.flow(z_p, y_mask, g=g, reverse=True)
+        # mode 1: uses m_p (like the SC-GlowTTS paper); mode 2: uses z_p (recommended for voice conversion)
+        if self.args.fine_tuning_mode == 1:
+            z_ft = m_p
+        elif self.args.fine_tuning_mode == 2:
+            z_ft = z_p
+        else:
+            raise RuntimeError(" [!] Invalid fine tuning mode!")
+
+        # run the flow in reverse to get the latents for the waveform decoder
+        z_f_pred = self.flow(z_ft, y_mask, g=g, reverse=True)
         z_slice, slice_ids = rand_segment(z_f_pred, y_lengths, self.spec_segment_size)

         o = self.waveform_decoder(z_slice, g=g)
+
+        wav_seg = segment(
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
+        )
+
+        if self.args.use_speaker_encoder_as_loss:
+            # concatenate the GT and generated waveforms
+            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+
+            # split the GT and generated speaker embeddings
+            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
+        else:
+            gt_spk_emb, syn_spk_emb = None, None
+
         outputs.update(
             {
                 "model_outputs": o,
                 "alignments": attn.squeeze(1),
+                "loss_duration": 0.0,
                 "z": z,
                 "z_p": z_p,
                 "m_p": m_p,
                 "logs_p": logs_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
+                "waveform_seg": wav_seg,
+                "gt_spk_emb": gt_spk_emb,
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
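
A minimal usage sketch of the new options for a voice conversion fine tuning run. Only the `VitsArgs` fields come from this patch; the `VitsConfig` import path, the `model_args` layout, and the file paths below are assumptions for illustration.

    # Sketch only: VitsConfig / model_args layout and the paths are assumptions, not part of this patch.
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        fine_tuning_mode=2,  # 2: feed z_p through the inverse flow, recommended for voice conversion
        use_speaker_encoder_as_loss=True,  # optionally add the speaker consistency loss
        speaker_encoder_config_path="/path/to/speaker_encoder_config.json",  # placeholder path
        speaker_encoder_model_path="/path/to/speaker_encoder_checkpoint.pth.tar",  # placeholder path
    )
    config = VitsConfig(model_args=model_args)
    # Training then starts from a pre-trained VITS checkpoint; per the docstring above,
    # only the vocoder part of the model is updated in this mode.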