Plot unnormalized pitch by `FastPitch`

This commit is contained in:
Eren Gölge 2021-09-06 14:26:36 +00:00
parent 2b59da802c
commit 8d41060d36
1 changed file with 24 additions and 9 deletions

View File

@ -14,9 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.soft_dtw import SoftDTW
@dataclass @dataclass
@ -232,7 +231,9 @@ class FastPitch(BaseTTS):
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
if self.args.use_aligner: if self.args.use_aligner:
self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels) self.aligner = AlignmentNetwork(
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
)
@staticmethod @staticmethod
def generate_attn(dr, x_mask, y_mask=None): def generate_attn(dr, x_mask, y_mask=None):
@ -307,7 +308,7 @@ class FastPitch(BaseTTS):
return x + g return x + g
def _forward_encoder( def _forward_encoder(
self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Encoding forward pass. """Encoding forward pass.
@ -522,13 +523,15 @@ class FastPitch(BaseTTS):
"durations_log": o_dr_log.squeeze(1), "durations_log": o_dr_log.squeeze(1),
"durations": o_dr.squeeze(1), "durations": o_dr.squeeze(1),
"attn_durations": o_attn, # for visualization "attn_durations": o_attn, # for visualization
"pitch": o_pitch, "pitch_avg": o_pitch,
"pitch_gt": avg_pitch, "pitch_avg_gt": avg_pitch,
"alignments": attn, "alignments": attn,
"alignment_soft": alignment_soft.transpose(1, 2), "alignment_soft": alignment_soft.transpose(1, 2),
"alignment_mas": alignment_mas.transpose(1, 2), "alignment_mas": alignment_mas.transpose(1, 2),
"o_alignment_dur": o_alignment_dur, "o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob, "alignment_logprob": alignment_logprob,
"x_mask": x_mask,
"y_mask": y_mask,
} }
return outputs return outputs
@ -577,6 +580,7 @@ class FastPitch(BaseTTS):
speaker_ids = batch["speaker_ids"] speaker_ids = batch["speaker_ids"]
durations = batch["durations"] durations = batch["durations"]
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
# forward pass # forward pass
outputs = self.forward( outputs = self.forward(
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
@ -593,12 +597,12 @@ class FastPitch(BaseTTS):
decoder_output_lens=mel_lengths, decoder_output_lens=mel_lengths,
dur_output=outputs["durations_log"], dur_output=outputs["durations_log"],
dur_target=durations, dur_target=durations,
pitch_output=outputs["pitch"], pitch_output=outputs["pitch_avg"],
pitch_target=outputs["pitch_gt"], pitch_target=outputs["pitch_avg_gt"],
input_lens=text_lengths, input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"], alignment_logprob=outputs["alignment_logprob"],
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
) )
# compute duration error # compute duration error
durations_pred = outputs["durations"] durations_pred = outputs["durations"]
@ -611,15 +615,26 @@ class FastPitch(BaseTTS):
model_outputs = outputs["model_outputs"] model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"] alignments = outputs["alignments"]
mel_input = batch["mel_input"] mel_input = batch["mel_input"]
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pred_spec = model_outputs[0].data.cpu().numpy() pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy()
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
figures = { figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False), "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False), "alignment": plot_alignment(align_img, output_fig=False),
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
} }
# plot the attention mask computed from the predicted durations # plot the attention mask computed from the predicted durations