diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py
index 352aebfa..1dd0bd68 100644
--- a/TTS/tts/models/fast_pitch.py
+++ b/TTS/tts/models/fast_pitch.py
@@ -14,9 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.soft_dtw import SoftDTW
 
 
 @dataclass
@@ -232,7 +231,9 @@ class FastPitch(BaseTTS):
             self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
 
         if self.args.use_aligner:
-            self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels)
+            self.aligner = AlignmentNetwork(
+                in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
+            )
 
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
@@ -307,7 +308,7 @@ class FastPitch(BaseTTS):
         return x + g
 
     def _forward_encoder(
-        self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None
+        self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Encoding forward pass.
 
@@ -522,13 +523,15 @@ class FastPitch(BaseTTS):
             "durations_log": o_dr_log.squeeze(1),
             "durations": o_dr.squeeze(1),
             "attn_durations": o_attn,  # for visualization
-            "pitch": o_pitch,
-            "pitch_gt": avg_pitch,
+            "pitch_avg": o_pitch,
+            "pitch_avg_gt": avg_pitch,
             "alignments": attn,
             "alignment_soft": alignment_soft.transpose(1, 2),
             "alignment_mas": alignment_mas.transpose(1, 2),
             "o_alignment_dur": o_alignment_dur,
             "alignment_logprob": alignment_logprob,
+            "x_mask": x_mask,
+            "y_mask": y_mask,
         }
         return outputs
 
@@ -577,6 +580,7 @@ class FastPitch(BaseTTS):
         speaker_ids = batch["speaker_ids"]
         durations = batch["durations"]
         aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
+        # forward pass
         outputs = self.forward(
             text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
         )
@@ -593,12 +597,12 @@ class FastPitch(BaseTTS):
             decoder_output_lens=mel_lengths,
             dur_output=outputs["durations_log"],
             dur_target=durations,
-            pitch_output=outputs["pitch"],
-            pitch_target=outputs["pitch_gt"],
+            pitch_output=outputs["pitch_avg"],
+            pitch_target=outputs["pitch_avg_gt"],
             input_lens=text_lengths,
             alignment_logprob=outputs["alignment_logprob"],
             alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
-            alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None
+            alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
         )
         # compute duration error
         durations_pred = outputs["durations"]
@@ -611,15 +615,26 @@ class FastPitch(BaseTTS):
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
+        pitch = batch["pitch"]
+        pitch_avg_expanded, _ = self.expand_encoder_outputs(
+            outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
+        )
 
         pred_spec = model_outputs[0].data.cpu().numpy()
         gt_spec = mel_input[0].data.cpu().numpy()
         align_img = alignments[0].data.cpu().numpy()
+        pitch = pitch[0, 0].data.cpu().numpy()
+
+        # TODO: denormalize before plotting
+        pitch = abs(pitch)
+        pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
 
         figures = {
             "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
             "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
             "alignment": plot_alignment(align_img, output_fig=False),
+            "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
+            "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
         }
 
         # plot the attention mask computed from the predicted durations