From d94b8bac029596821d51b1a26181ab7bceaac17c Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 16 May 2022 21:53:49 +0000
Subject: [PATCH] Add pitch predictor

---
 TTS/tts/configs/vits_config.py                |   5 +
 TTS/tts/datasets/dataset.py                   |   1 +
 TTS/tts/layers/losses.py                      |   6 +
 TTS/tts/models/vits.py                        | 163 +++++++++++++++++-
 ...th_prosody_encoder_with_pitch_predictor.py |  87 ++++++++++
 5 files changed, 260 insertions(+), 2 deletions(-)
 create mode 100644 tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py

diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index 9c2157a9..fbf71e1b 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig):
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
     consistency_loss_alpha: float = 1.0
+    pitch_loss_alpha: float = 5.0
 
     # data loader params
     return_wav: bool = True
@@ -149,6 +150,10 @@ class VitsConfig(BaseTTSConfig):
     d_vector_file: str = None
     d_vector_dim: int = None
 
+    # dataset configs
+    compute_f0: bool = False
+    f0_cache_path: str = None
+
     def __post_init__(self):
         for key, val in self.model_args.items():
             if hasattr(self, key):
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index d8f16e4e..ca94cb7e 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -122,6 +122,7 @@ class TTSDataset(Dataset):
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
         self.f0_cache_path = f0_cache_path
+        self.precompute_num_workers = precompute_num_workers
         self.min_audio_len = min_audio_len
         self.max_audio_len = max_audio_len
         self.min_text_len = min_text_len
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 99e66749..b7d87237 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -527,6 +527,7 @@ class AlignTTSLoss(nn.Module):
 class VitsGeneratorLoss(nn.Module):
     def __init__(self, c: Coqpit):
         super().__init__()
+        self.pitch_loss_alpha = c.pitch_loss_alpha
         self.kl_loss_alpha = c.kl_loss_alpha
         self.gen_loss_alpha = c.gen_loss_alpha
         self.feat_loss_alpha = c.feat_loss_alpha
@@ -606,6 +607,7 @@ class VitsGeneratorLoss(nn.Module):
         gt_cons_emb=None,
         syn_cons_emb=None,
         loss_spk_reversal_classifier=None,
+        pitch_loss=None,
     ):
         """
         Shapes:
@@ -644,6 +646,10 @@ class VitsGeneratorLoss(nn.Module):
         if loss_spk_reversal_classifier is not None:
             loss += loss_spk_reversal_classifier
             return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
+        if pitch_loss is not None:
+            pitch_loss = pitch_loss * self.pitch_loss_alpha
+            loss += pitch_loss
+            return_dict["pitch_loss"] = pitch_loss
 
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 2a9a4bbb..0eb91719 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,5 +1,7 @@
 import math
 import os
+import numpy as np
+import pyworld as pw
 from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union
@@ -16,7 +18,7 @@ from torch.utils.data import DataLoader
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
 from TTS.tts.configs.shared_configs import CharactersConfig
-from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.tacotron.gst_layers import GST
@@ -24,6 +26,8 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.data import prepare_data
+from TTS.tts.utils.helpers import average_over_durations
 from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
@@ -185,16 +189,66 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
     spec = amp_to_db(spec)
     return spec
 
+
+def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
+    """Compute pitch (f0) of a waveform using the same parameters used for computing the melspectrogram.
+
+    Args:
+        x (np.ndarray): Waveform.
+        sample_rate (int): Audio sampling rate.
+        hop_length (int): Hop length used for the spectrogram frames.
+        pitch_fmax (float): Maximum pitch value in Hz. Defaults to 800.0.
+
+    Returns:
+        np.ndarray: Pitch.
+    """
+    # align F0 length to the spectrogram length
+    if len(x) % hop_length == 0:
+        x = np.pad(x, (0, hop_length // 2), mode="reflect")
+
+    f0, t = pw.dio(
+        x.astype(np.double),
+        fs=sample_rate,
+        f0_ceil=pitch_fmax,
+        frame_period=1000 * hop_length / sample_rate,
+    )
+    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
+    return f0
+
+
 ##############################
 # DATASET
 ##############################
+
+
+class VITSF0Dataset(F0Dataset):
+    def __init__(self, config, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config = config
+
+    def compute_or_load(self, wav_file):
+        """Compute pitch and return a numpy array of pitch values."""
+        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
+        if not os.path.exists(pitch_file):
+            pitch = self._compute_and_save_pitch(wav_file, pitch_file)
+        else:
+            pitch = np.load(pitch_file)
+        return pitch.astype(np.float32)
+
+    def _compute_and_save_pitch(self, wav_file, pitch_file=None):
+        wav, _ = load_audio(wav_file)
+        pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
+        if pitch_file:
+            np.save(pitch_file, pitch)
+        return pitch
+
+
 class VitsDataset(TTSDataset):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
+        if self.compute_f0:
+            self.f0_dataset = VITSF0Dataset(
+                config,
+                samples=self.samples,
+                ap=self.ap,
+                cache_path=self.f0_cache_path,
+                precompute_num_workers=self.precompute_num_workers,
+            )
 
     def __getitem__(self, idx):
         item = self.samples[idx]
@@ -205,6 +259,11 @@ class VitsDataset(TTSDataset):
 
         token_ids = self.get_token_ids(idx, item["text"])
 
+        # get f0 values
+        f0 = None
+        if self.compute_f0:
+            f0 = self.get_f0(idx)["f0"]
+
         # after phonemization the text length may change
         # this is a shameful 🤭 hack to prevent longer phonemes
         # TODO: find a better fix
@@ -220,6 +279,7 @@ class VitsDataset(TTSDataset):
             "wav_file": wav_filename,
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
+            "pitch": f0,
         }
 
     @property
@@ -272,6 +332,14 @@ class VitsDataset(TTSDataset):
             wav = batch["wav"][i]
             wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)
 
+        # format F0
+        if self.compute_f0:
+            pitch = prepare_data(batch["pitch"])
+            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
+        else:
+            pitch = None
+
         return {
             "tokens": token_padded,
@@ -280,6 +348,7 @@ class VitsDataset(TTSDataset):
             "waveform": wav_padded,  # (B x T)
             "waveform_lens": wav_lens,  # (B)
             "waveform_rel_lens": wav_rel_lens,
+            "pitch": pitch,
             "speaker_names": batch["speaker_name"],
             "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
@@ -437,6 +506,21 @@ class VitsArgs(Coqpit):
         encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
 
+        use_pitch (bool):
+            Use pitch predictor to learn the pitch. Defaults to False.
+
+        pitch_predictor_hidden_channels (int):
+            Number of hidden channels in the pitch predictor. Defaults to 256.
+
+        pitch_predictor_dropout_p (float):
+            Dropout rate for the pitch predictor. Defaults to 0.1.
+
+        pitch_predictor_kernel_size (int):
+            Kernel size of conv layers in the pitch predictor. Defaults to 3.
+
+        pitch_embedding_kernel_size (int):
+            Kernel size of the projection layer in the pitch predictor. Defaults to 3.
+
         condition_dp_on_speaker (bool):
             Condition the duration predictor on the speaker embedding. Defaults to True.
 
@@ -509,6 +593,14 @@ class VitsArgs(Coqpit):
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
 
+    # Pitch predictor
+    use_pitch: bool = False
+    pitch_predictor_hidden_channels: int = 256
+    pitch_predictor_kernel_size: int = 3
+    pitch_predictor_dropout_p: float = 0.1
+    pitch_embedding_kernel_size: int = 3
+    detach_pp_input: bool = False
+
     detach_dp_input: bool = True
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
@@ -639,6 +731,21 @@ class Vits(BaseTTS):
                 language_emb_dim=self.embedded_language_dim,
             )
 
+        if self.args.use_pitch:
+            self.pitch_predictor = DurationPredictor(
+                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.pitch_predictor_hidden_channels,
+                self.args.pitch_predictor_kernel_size,
+                self.args.pitch_predictor_dropout_p,
+                cond_channels=dp_cond_embedding_dim,
+            )
+            self.pitch_emb = nn.Conv1d(
+                1,
+                self.args.hidden_channels,
+                kernel_size=self.args.pitch_embedding_kernel_size,
+                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
+            )
+
         if self.args.use_prosody_encoder:
             self.prosody_encoder = GST(
                 num_mel=self.args.hidden_channels,
@@ -878,6 +985,47 @@ class Vits(BaseTTS):
         g = speaker_ids if speaker_ids is not None else d_vectors
         return g
 
+    def forward_pitch_predictor(
+        self,
+        o_en: torch.FloatTensor,
+        x_mask: torch.IntTensor,
+        pitch: torch.FloatTensor = None,
+        dr: torch.IntTensor = None,
+        g_pp: torch.IntTensor = None,
+    ) -> torch.FloatTensor:
+        """Pitch predictor forward pass.
+
+        1. Predict pitch from the encoder outputs.
+        2. In training - Compute average pitch values for each input character from the ground truth pitch values.
+        3. Embed the average pitch values.
+
+        Args:
+            o_en (torch.FloatTensor): Encoder output.
+            x_mask (torch.IntTensor): Input sequence mask.
+            pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
+            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+            g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.
+
+        Returns:
+            torch.FloatTensor: Pitch loss.
+
+        Shapes:
+            - o_en: :math:`(B, C, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - pitch: :math:`(B, 1, T_{de})`
+            - dr: :math:`(B, T_{en})`
+        """
+        o_pitch = self.pitch_predictor(
+            o_en,
+            x_mask,
+            g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp,
+        )
+        # average the frame-level ground truth pitch over the durations to get per-token targets
+        avg_pitch = average_over_durations(pitch, dr.squeeze(1))
+        o_pitch_emb = self.pitch_emb(avg_pitch)
+        pitch_loss = torch.sum(torch.sum((o_pitch - avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
+        return pitch_loss
+
     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
@@ -920,6 +1068,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         waveform: torch.tensor,
+        pitch: torch.tensor,
         aux_input={
             "d_vectors": None,
             "speaker_ids": None,
@@ -1011,6 +1160,9 @@ class Vits(BaseTTS):
 
         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
 
+        pitch_loss = None
+        if self.args.use_pitch:
+            pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
+
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
@@ -1063,6 +1215,7 @@ class Vits(BaseTTS):
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
                 "loss_spk_reversal_classifier": l_pros_speaker,
+                "pitch_loss": pitch_loss,
             }
         )
         return outputs
@@ -1281,6 +1434,7 @@ class Vits(BaseTTS):
             emotion_embeddings = batch["emotion_embeddings"]
             emotion_ids = batch["emotion_ids"]
             waveform = batch["waveform"]
+            pitch = batch["pitch"]
 
             # generator pass
             outputs = self.forward(
@@ -1289,6 +1443,7 @@ class Vits(BaseTTS):
                 spec,
                 spec_lens,
                 waveform,
+                pitch,
                 aux_input={
                     "d_vectors": d_vectors,
                     "speaker_ids": speaker_ids,
@@ -1358,6 +1513,7 @@ class Vits(BaseTTS):
                     gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                     syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                     loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
+                    pitch_loss=self.model_outputs_cache["pitch_loss"],
                 )
 
         return self.model_outputs_cache, loss_dict
@@ -1613,6 +1769,7 @@ class Vits(BaseTTS):
         else:
             # init dataloader
             dataset = VitsDataset(
+                config=config,
                 samples=samples,
                 # batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_text_len=config.min_text_len,
@@ -1624,6 +1781,8 @@ class Vits(BaseTTS):
                 verbose=verbose,
                 tokenizer=self.tokenizer,
                 start_by_longest=config.start_by_longest,
+                compute_f0=config.get("compute_f0", False),
+                f0_cache_path=config.get("f0_cache_path", None),
             )
 
             # wait all the DDP process to be ready
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
new file mode 100644
index 00000000..9b13d501
--- /dev/null
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
@@ -0,0 +1,87 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    compute_f0=True,
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# active multispeaker d-vec mode
+config.model_args.use_speaker_embedding = True
+config.model_args.use_d_vector_file = False
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
+config.model_args.d_vector_dim = 128
+
+# prosody embedding
+config.model_args.use_prosody_encoder = True
+config.model_args.prosody_embedding_dim = 64
+
+# pitch predictor
+config.model_args.use_pitch = True
+config.model_args.condition_dp_on_speaker = True
+
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
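
Usage sketch (not part of the diff above): a minimal config enabling the new pitch predictor. Field names are the ones introduced in this patch; the cache path and values are placeholders, not defaults taken from anywhere else.

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig(
        compute_f0=True,                      # precompute per-utterance F0 with pyworld (dio + stonemask)
        f0_cache_path="/tmp/vits_f0_cache/",  # placeholder cache directory
        pitch_loss_alpha=5.0,                 # weight of the pitch term added to VitsGeneratorLoss
    )
    config.model_args.use_pitch = True
    config.model_args.pitch_predictor_hidden_channels = 256
    config.model_args.pitch_predictor_kernel_size = 3
    config.model_args.pitch_predictor_dropout_p = 0.1
    config.model_args.pitch_embedding_kernel_size = 3
    config.model_args.detach_pp_input = False  # if True, detach the speaker/prosody conditioning fed to the pitch predictor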
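The training-time supervision in forward_pitch_predictor is FastPitch-style: frame-level F0 is averaged over each token's duration and the predictor is trained to match those per-token averages. Below is a self-contained illustration of that averaging and the loss term; average_pitch_over_durations here is a simplified stand-in for TTS.tts.utils.helpers.average_over_durations, not the library function itself.

    import torch

    def average_pitch_over_durations(pitch: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
        # pitch: [B, 1, T_frames] frame-level F0; durs: [B, T_tokens] frames assigned to each token.
        ends = torch.cumsum(durs, dim=1)
        starts = ends - durs
        avg = torch.zeros(pitch.size(0), 1, durs.size(1), dtype=pitch.dtype)
        for b in range(pitch.size(0)):
            for t in range(durs.size(1)):
                s, e = int(starts[b, t]), int(ends[b, t])
                if e > s:
                    avg[b, 0, t] = pitch[b, 0, s:e].mean()
        return avg

    # toy example: one utterance, 3 tokens covering 2 + 3 + 1 spectrogram frames
    pitch = torch.tensor([[[110.0, 112.0, 220.0, 224.0, 228.0, 0.0]]])  # [1, 1, 6]
    durs = torch.tensor([[2, 3, 1]])                                    # [1, 3]
    x_mask = torch.ones(1, 1, 3)                                        # token mask, [B, 1, T_tokens]
    o_pitch = torch.zeros(1, 1, 3)                                      # stand-in for the predictor output

    avg_pitch = average_pitch_over_durations(pitch, durs)               # [[[111.0, 224.0, 0.0]]]
    pitch_loss = torch.sum(torch.sum((o_pitch - avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))

With real batches the padded token positions should be zeroed by x_mask before the sum, which is what the mask normalization in the patch is there for.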