diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py
index 7e294896..2c54803a 100644
--- a/TTS/tts/configs/fast_pitch_config.py
+++ b/TTS/tts/configs/fast_pitch_config.py
@@ -76,6 +76,7 @@ class FastPitchConfig(BaseTTSConfig):
     spec_loss_alpha: float = 1.0
     pitch_loss_alpha: float = 1.0
     dur_loss_alpha: float = 1.0
+    aligner_loss_alpha: float = 1.0
 
     # overrides
     min_seq_len: int = 13
diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py
index 7757ecf8..ee058095 100644
--- a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py
+++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py
@@ -47,8 +47,9 @@ def maximum_path(value, mask):
 def maximum_path_cython(value, mask):
     """Cython optimised version.
-    value: [b, t_x, t_y]
-    mask: [b, t_x, t_y]
+    Shapes:
+        - value: :math:`[B, T_en, T_de]`
+        - mask: :math:`[B, T_en, T_de]`
     """
     value = value * mask
     device = value.device
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index fefaec9a..6ca010dd 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -660,6 +660,36 @@ class VitsDiscriminatorLoss(nn.Module):
         return return_dict
 
 
+class ForwardSumLoss(nn.Module):
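+    """CTC-based forward-sum loss for the aligner.
+
+    Text-token indices act as the CTC target sequence and mel frames as the
+    CTC input sequence, so minimising the CTC loss maximises the summed
+    probability of all monotonic alignments (the forward-sum objective used
+    by aligners such as RAD-TTS). A blank column is padded at key index 0
+    with ``blank_logprob``.
+    """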
+    def __init__(self, blank_logprob=-1):
+        super().__init__()
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
+        self.blank_logprob = blank_logprob
+
+    def forward(self, attn_logprob, in_lens, out_lens):
+        key_lens = in_lens
+        query_lens = out_lens
+        attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
+
+        total_loss = 0.0
+        for bid in range(attn_logprob.shape[0]):
+            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
+            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
+
+            curr_logprob = self.log_softmax(curr_logprob[None])[0]
+            loss = self.ctc_loss(
+                curr_logprob,
+                target_seq,
+                input_lengths=query_lens[bid : bid + 1],
+                target_lengths=key_lens[bid : bid + 1],
+            )
+            total_loss = total_loss + loss
+
+        total_loss = total_loss / attn_logprob.shape[0]
+        return total_loss
+
+
 class FastPitchLoss(nn.Module):
     def __init__(self, c):
         super().__init__()
@@ -667,11 +697,14 @@ class FastPitchLoss(nn.Module):
         self.ssim = SSIMLoss()
         self.dur_loss = MSELossMasked(False)
         self.pitch_loss = MSELossMasked(False)
+        if c.model_args.use_aligner:
+            self.aligner_loss = ForwardSumLoss()
 
         self.spec_loss_alpha = c.spec_loss_alpha
         self.ssim_loss_alpha = c.ssim_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.pitch_loss_alpha = c.pitch_loss_alpha
+        self.aligner_loss_alpha = c.aligner_loss_alpha
 
     def forward(
         self,
@@ -683,29 +716,35 @@ class FastPitchLoss(nn.Module):
         pitch_output,
         pitch_target,
         input_lens,
+        alignment_logprob=None,
     ):
         loss = 0
         return_dict = {}
         if self.ssim_loss_alpha > 0:
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
-            loss += self.ssim_loss_alpha * ssim_loss
+            loss = loss + self.ssim_loss_alpha * ssim_loss
             return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
 
         if self.spec_loss_alpha > 0:
             spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
-            loss += self.spec_loss_alpha * spec_loss
+            loss = loss + self.spec_loss_alpha * spec_loss
             return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
 
         if self.dur_loss_alpha > 0:
             log_dur_tgt = torch.log(dur_target.float() + 1)
             dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
-            loss += self.dur_loss_alpha * dur_loss
+            loss = loss + self.dur_loss_alpha * dur_loss
             return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
 
         if self.pitch_loss_alpha > 0:
             pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
-            loss += self.pitch_loss_alpha * pitch_loss
+            loss = loss + self.pitch_loss_alpha * pitch_loss
             return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
 
+        if self.aligner_loss_alpha > 0 and alignment_logprob is not None:
+            aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
+            loss = loss + self.aligner_loss_alpha * aligner_loss
+            return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
+
         return_dict["loss"] = loss
         return return_dict
diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py
index 60a1654a..6f9cee36 100644
--- a/TTS/tts/models/fast_pitch.py
+++ b/TTS/tts/models/fast_pitch.py
@@ -4,8 +4,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from coqpit import Coqpit
 
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -14,6 +15,75 @@ from TTS.utils.audio import AudioProcessor
 
 # pylint: disable=dangerous-default-value
 
 
+class AlignmentEncoder(torch.nn.Module):
+    """Aligns the input text with the mel spectrogram frames."""
+
+    def __init__(
+        self,
+        in_query_channels=80,
+        in_key_channels=512,
+        attn_channels=80,
+        temperature=0.0005,
+    ):
+        super().__init__()
+        self.temperature = temperature
+        self.softmax = torch.nn.Softmax(dim=3)
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+
+        self.key_proj = nn.Sequential(
+            ConvNorm(
+                in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
+            ),
+            torch.nn.ReLU(),
+            ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False),
+        )
+
+        self.query_proj = nn.Sequential(
+            ConvNorm(
+                in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
+            ),
+            torch.nn.ReLU(),
+            ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False),
+            torch.nn.ReLU(),
+            ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False),
+        )
+
+    def forward(
+        self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor = None, attn_prior: torch.Tensor = None
+    ):
+        """Forward pass of the aligner encoder.
+
+        Args:
+            queries (torch.Tensor): query tensor (mel frames).
+            keys (torch.Tensor): key tensor (text embeddings).
+            mask (torch.Tensor): uint8 binary mask for variable-length entries (in the T_en domain).
+            attn_prior (torch.Tensor): prior for the attention matrix.
+
+        Shapes:
+            - queries: :math:`(B, C, T_de)`
+            - keys: :math:`(B, C_emb, T_en)`
+            - mask: :math:`(B, 1, T_en)`
+
+        Output:
+            attn (torch.Tensor): :math:`(B, 1, T_de, T_en)` attention map. The last dim sums to 1 over T_en.
+            attn_logprob (torch.Tensor): :math:`(B, 1, T_de, T_en)` log-prob attention map.
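+
+        Example (illustrative tensor names; mirrors the call site in ``FastPitch.forward``)::
+
+            attn, attn_logprob = aligner(mel.transpose(1, 2), token_emb.transpose(1, 2), x_mask, None)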
+ """ + keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 + queries_enc = self.query_proj(queries) + + # Simplistic Gaussian Isotopic Attention + attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2 + attn = -self.temperature * attn.sum(1, keepdim=True) + + if attn_prior is not None: + attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8) + + attn_logprob = attn.clone() + + if mask is not None: + attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + + attn = self.softmax(attn) # softmax along T2 + return attn, attn_logprob + + def mask_from_lens(lens, max_len: int = None): if max_len is None: max_len = lens.max() @@ -379,6 +449,7 @@ class FastPitchArgs(Coqpit): detach_duration_predictor: bool = False max_duration: int = 75 use_gt_duration: bool = True + use_aligner: bool = True class FastPitch(BaseTTS): @@ -421,6 +492,7 @@ class FastPitch(BaseTTS): self.max_duration = args.max_duration self.use_gt_duration = args.use_gt_duration + self.use_aligner = args.use_aligner self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale @@ -463,6 +535,9 @@ class FastPitch(BaseTTS): self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True) + if args.use_aligner: + self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels) + @staticmethod def expand_encoder_outputs(en, dr, x_mask, y_mask): """Generate attention alignment map from durations and @@ -483,10 +558,16 @@ class FastPitch(BaseTTS): o_en_ex = torch.matmul(attn.transpose(1, 2), en) return o_en_ex, attn.transpose(1, 2) - def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}): + def forward( + self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None} + ): speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + o_alignment_dur = None + alignment_logprob = None + alignment_mas = None # Calculate speaker embedding # if self.speaker_emb is None: @@ -508,6 +589,16 @@ class FastPitch(BaseTTS): o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # Aligner + if self.use_aligner: + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1)) + avg_pitch = average_pitch(pitch, o_alignment_dur) + dr = o_alignment_dur + # TODO: move this to the dataset avg_pitch = average_pitch(pitch, dr) @@ -529,6 +620,9 @@ class FastPitch(BaseTTS): "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, } return outputs @@ -599,7 +693,12 @@ class FastPitch(BaseTTS): durations = batch["durations"] aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} - outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input) + outputs = self.forward( + text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, 
+            alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
+            alignment_mas = maximum_path(
+                alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
+            )
+            o_alignment_dur = torch.sum(alignment_mas, -1)
+            dr = o_alignment_dur
+
         # TODO: move this to the dataset
         avg_pitch = average_pitch(pitch, dr)
@@ -529,6 +620,9 @@
             "pitch": o_pitch,
             "pitch_gt": avg_pitch,
             "alignments": attn,
+            "alignment_mas": alignment_mas,
+            "o_alignment_dur": o_alignment_dur,
+            "alignment_logprob": alignment_logprob,
         }
         return outputs
 
@@ -599,7 +693,12 @@ class FastPitch(BaseTTS):
         durations = batch["durations"]
         aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
 
-        outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input)
+        outputs = self.forward(
+            text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
+        )
+
+        if self.use_aligner:
+            durations = outputs["o_alignment_dur"]
 
         # compute loss
         loss_dict = criterion(
@@ -611,6 +710,7 @@
             outputs["pitch"],
             outputs["pitch_gt"],
             text_lengths,
+            outputs["alignment_logprob"],
         )
 
         # compute duration error
@@ -634,6 +734,10 @@
             "alignment": plot_alignment(align_img, output_fig=False),
         }
 
+        if self.config.model_args.use_aligner and self.training:
+            alignment_mas = outputs["alignment_mas"]
+            figures["alignment_mas"] = plot_alignment(alignment_mas[0].data.cpu().numpy(), output_fig=False)
+
         # Sample audio
         train_audio = ap.inv_melspectrogram(pred_spec.T)
         return figures, {"audio": train_audio}
diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py
index 4b852d12..63f50dd9 100644
--- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py
+++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py
@@ -11,7 +11,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
     name="ljspeech",
     meta_file_train="metadata.csv",
-    meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
+    # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
     path=os.path.join(output_path, "../LJSpeech-1.1/"),
 )
 audio_config = BaseAudioConfig(
@@ -46,14 +46,17 @@ config = FastPitchConfig(
     print_eval=False,
     mixed_precision=False,
     output_path=output_path,
-    datasets=[dataset_config]
+    datasets=[dataset_config],
 )
 
 # compute alignments
-manager = ModelManager()
-model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
-# TODO: make compute_attention python callable
-os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")
+if not config.model_args.use_aligner:
+    manager = ModelManager()
+    model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+    # TODO: make compute_attention python callable
+    os.system(
+        f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
+    )
 
 # train the model
 args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
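+
+# NOTE: with `use_aligner=True` (the FastPitchArgs default), the compute-alignments
+# step above is skipped and durations are learned jointly by the model's
+# AlignmentEncoder, supervised through ForwardSumLoss.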