diff --git a/TTS/bin/compute_vits_alignments.py b/TTS/bin/compute_vits_alignments.py
index df4e0b65..1d70d87b 100644
--- a/TTS/bin/compute_vits_alignments.py
+++ b/TTS/bin/compute_vits_alignments.py
@@ -69,6 +69,9 @@ def extract_aligments(
         for idx in range(tokens.shape[0]):
             wav_file_path = item_idx[idx]
             alignment = alignments[idx]
+            spec_length = spec_lens[idx]
+            token_length = token_lenghts[idx]
+            alignment = alignment[:token_length, :spec_length]
             # set paths
             align_file_name = os.path.splitext(os.path.basename(wav_file_path))[0] + ".npy"
             os.makedirs(os.path.join(output_path, "alignments"), exist_ok=True)
diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index cf3cd14b..dbcb7313 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -40,6 +40,7 @@ class TextEncoder(nn.Module):
         language_emb_dim: int = None,
         emotion_emb_dim: int = None,
         prosody_emb_dim: int = None,
+        pitch_dim: int = None,
     ):
         """Text Encoder for VITS model.
@@ -70,6 +71,9 @@ class TextEncoder(nn.Module):
         if prosody_emb_dim:
             hidden_channels += prosody_emb_dim

+        if pitch_dim:
+            hidden_channels += pitch_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,
@@ -85,7 +89,7 @@

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
+    def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None, pitch_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
@@ -105,6 +109,9 @@ class TextEncoder(nn.Module):
         if pros_emb is not None:
             x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

+        if pitch_emb is not None:
+            x = torch.cat((x, pitch_emb.transpose(2, 1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index c54850c6..eb2e9976 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -265,7 +265,9 @@ class VitsDataset(TTSDataset):
         self.pad_id = self.tokenizer.characters.pad_id
         self.model_args = model_args
         self.compute_pitch = compute_pitch
-
+        self.use_precomputed_alignments = model_args.use_precomputed_alignments
+        self.alignments_cache_path = model_args.alignments_cache_path
+
         if self.compute_pitch:
             self.f0_dataset = VITSF0Dataset(config,
                 samples=self.samples, ap=self.ap, cache_path=self.f0_cache_path, precompute_num_workers=self.precompute_num_workers
@@ -289,6 +291,11 @@ class VitsDataset(TTSDataset):
         if self.compute_pitch:
             f0 = self.get_f0(idx)["f0"]

+        alignments = None
+        if self.use_precomputed_alignments:
+            align_file = os.path.join(self.alignments_cache_path, os.path.splitext(wav_filename)[0] + ".npy")
+            alignments = self.get_attn_mask(align_file)
+
         # after phonemization the text length may change
         # this is a shameful 🤭 hack to prevent longer phonemes
         # TODO: find a better fix
@@ -305,6 +312,8 @@ class VitsDataset(TTSDataset):
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
             "pitch": f0,
+            "alignments": alignments,
+
         }

     @property
@@ -365,6 +374,18 @@ class VitsDataset(TTSDataset):
             pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
         else:
             pitch = None
+
+        padded_alignments = None
+        if self.use_precomputed_alignments:
+            alignments = batch["alignments"]
+            max_len_1 = max((x.shape[0] for x in alignments))
+            max_len_2 = max((x.shape[1] for x in alignments))
+            padded_alignments = []
+            for x in alignments:
+                padded_alignment = np.pad(x, ((0, max_len_1 - x.shape[0]), (0, max_len_2 - x.shape[1])), mode="constant", constant_values=0)
+                padded_alignments.append(padded_alignment)
+
+            padded_alignments = torch.FloatTensor(np.stack(padded_alignments)).unsqueeze(1)

         return {
             "tokens": token_padded,
@@ -378,6 +399,7 @@ class VitsDataset(TTSDataset):
             "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
             "raw_text": batch["raw_text"],
+            "alignments": padded_alignments,
         }

@@ -385,7 +407,6 @@
 # MODEL DEFINITION
 ##############################

-
 @dataclass
 class VitsArgs(Coqpit):
     """VITS model arguments.
@@ -664,6 +685,9 @@ class VitsArgs(Coqpit):
     pitch_predictor_dropout_p: float = 0.1
     pitch_embedding_kernel_size: int = 3
     detach_pp_input: bool = False
+    use_precomputed_alignments: bool = False
+    alignments_cache_path: str = ""
+    pitch_embedding_dim: int = 0
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -751,6 +775,7 @@ class Vits(BaseTTS):
             language_emb_dim=self.embedded_language_dim,
             emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
             prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor else 0,
+            pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -791,6 +816,9 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

+        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
+            dp_extra_inp_dim += self.args.pitch_embedding_dim
+
         if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
                 self.args.hidden_channels + dp_extra_inp_dim,
@@ -814,13 +842,13 @@ class Vits(BaseTTS):
         if self.args.use_pitch:
             if self.args.use_pitch_on_enc_input:
                 self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
-            else:
-                self.pitch_emb = nn.Conv1d(
-                    1,
-                    self.args.hidden_channels,
-                    kernel_size=self.args.pitch_predictor_kernel_size,
-                    padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
-                )
+
+            self.pitch_emb = nn.Conv1d(
+                1,
+                self.args.hidden_channels if not self.args.use_pitch_on_enc_input else self.args.pitch_embedding_dim,
+                kernel_size=self.args.pitch_predictor_kernel_size,
+                padding=int((self.args.pitch_predictor_kernel_size - 1) / 2),
+            )
             self.pitch_predictor = DurationPredictor(
                 self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
@@ -1241,17 +1269,16 @@ class Vits(BaseTTS):
         )

         pitch_loss = None
-        gt_avg_pitch = None
+        pred_avg_pitch_emb = None
+        gt_avg_pitch_emb = None
         if pitch is not None:
             gt_avg_pitch = average_over_durations(pitch, dr.squeeze()).detach()
             pitch_loss = torch.sum(torch.sum((gt_avg_pitch - pred_avg_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
-            if not self.args.use_pitch_on_enc_input:
-                gt_agv_pitch = self.pitch_emb(gt_avg_pitch)
+            gt_avg_pitch_emb = self.pitch_emb(gt_avg_pitch)
         else:
-            if not self.args.use_pitch_on_enc_input:
-                pred_avg_pitch = self.pitch_emb(pred_avg_pitch)
+            pred_avg_pitch_emb = self.pitch_emb(pred_avg_pitch)

-        return pitch_loss, gt_agv_pitch, pred_avg_pitch
+        return pitch_loss, gt_avg_pitch_emb, pred_avg_pitch_emb

     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
@@ -1313,6 +1340,7 @@ class Vits(BaseTTS):
         y_lengths: torch.tensor,
         waveform: torch.tensor,
         pitch: torch.tensor,
+        alignments: torch.tensor,
         aux_input={
             "d_vectors": None,
             "speaker_ids": None,
@@ -1389,6 +1417,21 @@ class Vits(BaseTTS):
         if self.args.use_speaker_embedding or self.args.use_d_vector_file:
             g = F.normalize(self.speaker_embedding_squeezer(g.squeeze(-1))).unsqueeze(-1)

+        # duration predictor
+        g_dp = g if self.args.condition_dp_on_speaker else None
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+            if g_dp is None:
+                g_dp = eg
+            else:
+                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]
+
+        pitch_loss = None
+        gt_avg_pitch_emb = None
+        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
+            if alignments is None:
+                raise RuntimeError(" [!] To condition the pitch on the Text Encoder you need to provide external alignments!")
+            pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(x, x_lengths, pitch, alignments.sum(3), g_dp)
+
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -1418,13 +1461,14 @@ class Vits(BaseTTS):
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
             if self.args.use_prosody_enc_emo_classifier:
                 _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
-        x_input = x
+
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
             pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
+            pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )

         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1437,14 +1481,6 @@ class Vits(BaseTTS):
         if self.args.use_text_enc_emo_classifier:
             _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)

-        # duration predictor
-        g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-            if g_dp is None:
-                g_dp = eg
-            else:
-                g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]
-
         if self.args.use_prosody_encoder:
             if g_dp is None:
                 g_dp = pros_emb
@@ -1453,7 +1489,6 @@ class Vits(BaseTTS):

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

-        pitch_loss = None
         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
             pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
             m_p = m_p + gt_avg_pitch_emb
@@ -1631,14 +1666,6 @@ class Vits(BaseTTS):
             pros_emb = pros_emb.transpose(1, 2)

-        x, m_p, logs_p, x_mask = self.text_encoder(
-            x,
-            x_lengths,
-            lang_emb=lang_emb,
-            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
-        )
-
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
@@ -1647,6 +1674,19 @@ class Vits(BaseTTS):
             else:
                 g_dp = torch.cat([g_dp, eg], dim=1)  # [b, h1+h2, 1]

+        pred_avg_pitch_emb = None
+        if self.args.use_pitch and self.args.use_pitch_on_enc_input:
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(
+            x,
+            x_lengths,
+            lang_emb=lang_emb,
+            emo_emb=eg if not self.args.use_noise_scale_predictor else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
+            pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
+        )
+
         if self.args.use_prosody_encoder:
             if g_dp is None:
                 g_dp = pros_emb
@@ -1673,7 +1713,6 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

-        pred_avg_pitch_emb = None
         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
             _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
             m_p = m_p + pred_avg_pitch_emb
@@ -1819,6 +1858,7 @@ class Vits(BaseTTS):
         emotion_ids = batch["emotion_ids"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        alignments = batch["alignments"]

         # generator pass
         outputs = self.forward(
@@ -1828,6 +1868,7 @@ class Vits(BaseTTS):
             spec_lens,
             waveform,
             pitch,
+            alignments,
             aux_input={
                 "d_vectors": d_vectors,
                 "speaker_ids": speaker_ids,
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0001.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0001.npy
new file mode 100644
index 00000000..c2b961bb
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0001.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0002.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0002.npy
new file mode 100644
index 00000000..0a227d87
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0002.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0003.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0003.npy
new file mode 100644
index 00000000..bc307396
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0003.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0004.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0004.npy
new file mode 100644
index 00000000..6becf176
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0004.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0005.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0005.npy
new file mode 100644
index 00000000..99d27e86
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0005.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0006.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0006.npy
new file mode 100644
index 00000000..16da9744
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0006.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0007.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0007.npy
new file mode 100644
index 00000000..89899435
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0007.npy differ
diff --git a/tests/data/ljspeech/mas_alignments/alignments/LJ001-0008.npy b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0008.npy
new file mode 100644
index 00000000..b545ad08
Binary files /dev/null and b/tests/data/ljspeech/mas_alignments/alignments/LJ001-0008.npy differ
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
similarity index 89%
rename from tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
rename to tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
index 30d9f0f6..f3fe2bd5 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_pitch_predictor.py
@@ -28,7 +28,7 @@ config = VitsConfig(
     compute_pitch=True,
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
+        ["Be a voice, not an echo.", "ljspeech-1", None, None, None],
     ],
 )
 # set audio config
@@ -42,17 +42,18 @@ config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
 config.model_args.speaker_embedding_channels = 128
 config.model_args.d_vector_dim = 128

-# prosody embedding
-config.model_args.use_prosody_encoder = True
-config.model_args.prosody_embedding_dim = 64
+
+config.model_args.use_precomputed_alignments = True
+config.model_args.alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"

 # pitch predictor
 config.model_args.use_pitch = True
+config.model_args.use_pitch_on_enc_input = True
+config.model_args.pitch_embedding_dim = 2
 config.model_args.condition_dp_on_speaker = True

 config.save_json(config_path)

-
 # train the model for one epoch
 command_train = (
     f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
@@ -74,11 +75,9 @@ continue_config_path = os.path.join(continue_path, "config.json")
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
-style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")

-
-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
 run_cli(inference_command)

 # restore the model and continue training for one more epoch
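
Note (not part of the diff): a minimal sketch of how the precomputed MAS alignments added above are meant to flow from TTS/bin/compute_vits_alignments.py into VitsDataset.collate_fn and Vits.forward. The .npy path points at the test data added in this PR; the pretend batch maxima and the final duration check are illustrative assumptions, not code from this branch.

# Sketch only: load one precomputed alignment, pad it the way the new collate_fn code does,
# and recover per-token durations the way Vits.forward does via `alignments.sum(3)`.
import os

import numpy as np
import torch

alignments_cache_path = "tests/data/ljspeech/mas_alignments/alignments/"
align_file = os.path.join(alignments_cache_path, "LJ001-0001.npy")

# Each file stores a hard alignment trimmed to [n_tokens, n_spec_frames]
# (see the new `alignment[:token_length, :spec_length]` slice in compute_vits_alignments.py).
alignment = np.load(align_file)
n_tokens, n_frames = alignment.shape

# collate_fn pads every alignment in the batch to the largest token/frame count and stacks
# them into a [B, 1, T_en_max, T_de_max] float tensor, as in the diff above.
max_tokens, max_frames = n_tokens + 3, n_frames + 10  # pretend batch-level maxima (assumption)
padded = np.pad(alignment, ((0, max_tokens - n_tokens), (0, max_frames - n_frames)), mode="constant")
batch_alignments = torch.FloatTensor(padded)[None, None]  # [1, 1, T_en_max, T_de_max]

# Vits.forward recovers per-token durations with `alignments.sum(3)` and hands them to
# forward_pitch_predictor, which averages the frame-level F0 over each token's frames.
durations = batch_alignments.sum(3)  # [1, 1, T_en_max]
# For a hard (0/1) alignment the durations should add up to the number of spec frames.
print(durations.sum().item(), n_frames)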