diff --git a/TTS/bin/compute_vits_alignments.py b/TTS/bin/compute_vits_alignments.py
index 5f27b2cc..df4e0b65 100644
--- a/TTS/bin/compute_vits_alignments.py
+++ b/TTS/bin/compute_vits_alignments.py
@@ -24,54 +24,56 @@ def extract_aligments(
     data_loader, model, output_path, use_cuda=True
 ):
     model.eval()
-    export_metadata = []
-    for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
+    with torch.no_grad():
+        for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):

-        batch = model.format_batch(batch)
-        if use_cuda:
-            for k, v in batch.items():
-                batch[k] = to_cuda(v)
+            batch = model.format_batch(batch)
+            if use_cuda:
+                for k, v in batch.items():
+                    batch[k] = to_cuda(v)

-        batch = model.format_batch_on_device(batch)
+            batch = model.format_batch_on_device(batch)

-        spec_lens = batch["spec_lens"]
-        tokens = batch["tokens"]
-        token_lenghts = batch["token_lens"]
-        spec = batch["spec"]
+            spec_lens = batch["spec_lens"]
+            tokens = batch["tokens"]
+            token_lenghts = batch["token_lens"]
+            spec = batch["spec"]

-        d_vectors = batch["d_vectors"]
-        speaker_ids = batch["speaker_ids"]
-        language_ids = batch["language_ids"]
-        emotion_embeddings = batch["emotion_embeddings"]
-        emotion_ids = batch["emotion_ids"]
-        waveform = batch["waveform"]
-        item_idx = batch["audio_files"]
-        # generator pass
-        outputs = model.forward(
-            tokens,
-            token_lenghts,
-            spec,
-            spec_lens,
-            waveform,
-            aux_input={
-                "d_vectors": d_vectors,
-                "speaker_ids": speaker_ids,
-                "language_ids": language_ids,
-                "emotion_embeddings": emotion_embeddings,
-                "emotion_ids": emotion_ids,
-            },
-        )
+            d_vectors = batch["d_vectors"]
+            speaker_ids = batch["speaker_ids"]
+            language_ids = batch["language_ids"]
+            emotion_embeddings = batch["emotion_embeddings"]
+            emotion_ids = batch["emotion_ids"]
+            waveform = batch["waveform"]
+            item_idx = batch["audio_files"]
+            pitch = batch["pitch"]
+            # generator pass
+            outputs = model.forward(
+                tokens,
+                token_lenghts,
+                spec,
+                spec_lens,
+                waveform,
+                pitch,
+                aux_input={
+                    "d_vectors": d_vectors,
+                    "speaker_ids": speaker_ids,
+                    "language_ids": language_ids,
+                    "emotion_embeddings": emotion_embeddings,
+                    "emotion_ids": emotion_ids,
+                },
+            )

-        alignments = outputs["alignments"].detach().cpu().numpy()
+            alignments = outputs["alignments"].detach().cpu().numpy()

-        for idx in range(tokens.shape[0]):
-            wav_file_path = item_idx[idx]
-            alignment = alignments[idx]
-            # set paths
-            align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
-            os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
-            align_file_path = os.path.join(output_path, "alignments", align_file_name)
-            np.save(align_file_path, alignment)
+            for idx in range(tokens.shape[0]):
+                wav_file_path = item_idx[idx]
+                alignment = alignments[idx]
+                # set paths
+                align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
+                os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
+                align_file_path = os.path.join(output_path, "alignments", align_file_name)
+                np.save(align_file_path, alignment)
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index dcc1071c..c54850c6 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -657,6 +657,7 @@ class VitsArgs(Coqpit):
     use_avg_feature_on_latent_discriminator: bool = False

     # Pitch predictor
+    use_pitch_on_enc_input: bool = False
     use_pitch: bool = False
     pitch_predictor_hidden_channels: int = 256
     pitch_predictor_kernel_size: int = 3
@@ -811,12 +812,22 @@ class Vits(BaseTTS):
         )

         if self.args.use_pitch:
+            if self.args.use_pitch_on_enc_input:
+                self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
+            else:
+                self.pitch_emb = nn.Conv1d(
+                    1,
+                    self.args.hidden_channels,
+                    kernel_size=self.args.pitch_predictor_kernel_size,
+                    padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
+                )
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
+                self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
                 cond_channels=dp_cond_embedding_dim,
+                language_emb_dim=self.embedded_language_dim,
             )

         if self.args.use_prosody_encoder:
@@ -1190,7 +1201,7 @@ class Vits(BaseTTS):
     def forward_pitch_predictor(
         self,
         o_en: torch.FloatTensor,
-        x_mask: torch.IntTensor,
+        x_lengths: torch.IntTensor,
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
         g_pp: torch.IntTensor = None,
@@ -1217,15 +1228,30 @@ class Vits(BaseTTS):
             - pitch: :math:`(B, 1, T_{de})`
             - dr: :math:`(B, T_{en})`
         """
-        o_pitch = self.pitch_predictor(
+        if self.args.use_pitch_on_enc_input:
+            o_en = self.pitch_predictor_vocab_emb(o_en)
+            o_en = torch.transpose(o_en, 1, -1)  # [b, h, t]
+
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, o_en.size(2)), 1).to(o_en.dtype)  # [b, 1, t]
+
+        pred_avg_pitch = self.pitch_predictor(
             o_en, x_mask, g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )
-        avg_pitch = average_over_durations(pitch, dr.squeeze())
-        pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
-        return pitch_loss
+        pitch_loss = None
+        gt_avg_pitch = None
+        if pitch is not None:
+            gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
+            pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
+            if not self.args.use_pitch_on_enc_input:
+                gt_avg_pitch = self.pitch_emb(gt_avg_pitch)
+        else:
+            if not self.args.use_pitch_on_enc_input:
+                pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
+
+        return pitch_loss, gt_avg_pitch, pred_avg_pitch

     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
@@ -1392,7 +1418,7 @@ class Vits(BaseTTS):
             _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
         if self.args.use_prosody_enc_emo_classifier:
             _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
-
+        x_input = x
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1428,8 +1454,9 @@ class Vits(BaseTTS):
         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

         pitch_loss = None
-        if self.args.use_pitch:
-            pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
+            m_p = m_p + gt_avg_pitch_emb

         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@@ -1646,6 +1673,12 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

+        pred_avg_pitch_emb = None
+        if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
+            m_p = m_p + pred_avg_pitch_emb
+
+
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
@@ -1683,6 +1716,7 @@ class Vits(BaseTTS):
             "m_p": m_p,
             "logs_p": logs_p,
             "y_mask": y_mask,
+            "pitch": pred_avg_pitch_emb,
         }
         return outputs
@@ -1693,7 +1727,7 @@ class Vits(BaseTTS):
                 " [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
             )
         y = wav_to_spec(
-            style_wav,
+            style_wav.unsqueeze(1),
             self.config.audio.fft_size,
             self.config.audio.hop_length,
             self.config.audio.win_length,
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index 7dfa4f01..ad8152fc 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -25,6 +25,8 @@ config = VitsConfig(
     epochs=1,
     print_step=1,
     print_eval=True,
+    compute_pitch=True,
+    f0_cache_path="tests/data/ljspeech/f0_cache/",
     test_sentences=[
         ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
     ],
@@ -57,12 +59,15 @@ config.model_args.use_latent_discriminator = True
 config.model_args.use_noise_scale_predictor = False
 config.model_args.condition_pros_enc_on_speaker = True

-config.model_args.use_pros_enc_input_as_pros_emb = True
-config.model_args.use_prosody_embedding_squeezer = True
-config.model_args.prosody_embedding_squeezer_input_dim = 192
+config.model_args.use_pros_enc_input_as_pros_emb = False
+config.model_args.use_prosody_embedding_squeezer = False
+config.model_args.prosody_embedding_squeezer_input_dim = 0
+
+# pitch predictor
+config.model_args.use_pitch = True
+config.model_args.use_pitch_on_enc_input = False
+config.model_args.condition_dp_on_speaker = False

-# enable end2end loss
-config.model_args.use_end2end_loss = False
 config.mixed_precision = False
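
Note (not part of the patch): with `use_pitch_on_enc_input` disabled, the new pitch path averages the frame-level F0 over the per-token durations, embeds the token-level values through the 1-channel `pitch_emb` Conv1d, and adds the result to the prior mean `m_p` before it is expanded by the alignment. The standalone sketch below only mirrors that shape flow; the tensor sizes and the `avg_over_durations` helper are illustrative stand-ins, not the repository's implementation.

import torch
import torch.nn as nn

# Illustrative sizes only: batch, token length, frame length, hidden channels.
B, T_en, T_de, H = 2, 13, 52, 192

def avg_over_durations(pitch, durations):
    # Stand-in for the repo's average_over_durations: mean F0 per token.
    # pitch: (B, 1, T_de), durations: (B, T_en) integer frame counts.
    ends = torch.cumsum(durations, dim=1)
    starts = ends - durations
    out = torch.zeros(pitch.size(0), 1, durations.size(1))
    for b in range(pitch.size(0)):
        for t in range(durations.size(1)):
            s, e = int(starts[b, t]), int(ends[b, t])
            if e > s:
                out[b, 0, t] = pitch[b, 0, s:e].mean()
    return out

durations = torch.full((B, T_en), T_de // T_en)      # toy per-token durations
pitch = torch.randn(B, 1, T_de)                      # frame-level F0 track
gt_avg_pitch = avg_over_durations(pitch, durations)  # (B, 1, T_en)

# 1 -> H channel projection, analogous to the patch's self.pitch_emb Conv1d.
pitch_emb = nn.Conv1d(1, H, kernel_size=3, padding=1)
gt_avg_pitch_emb = pitch_emb(gt_avg_pitch)           # (B, H, T_en)

# Token-level pitch embedding is added to the prior mean before expansion.
m_p = torch.randn(B, H, T_en)
m_p = m_p + gt_avg_pitch_emb
print(m_p.shape)  # torch.Size([2, 192, 13])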