Plot unnormalized pitch by `FastPitch`

This commit is contained in:
Eren Gölge 2021-09-06 14:26:36 +00:00
parent 2b59da802c
commit 8d41060d36
1 changed files with 24 additions and 9 deletions

View File

@ -14,9 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.soft_dtw import SoftDTW
@dataclass
@ -232,7 +231,9 @@ class FastPitch(BaseTTS):
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
if self.args.use_aligner:
self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels)
self.aligner = AlignmentNetwork(
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
)
@staticmethod
def generate_attn(dr, x_mask, y_mask=None):
@ -307,7 +308,7 @@ class FastPitch(BaseTTS):
return x + g
def _forward_encoder(
self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Encoding forward pass.
@ -522,13 +523,15 @@ class FastPitch(BaseTTS):
"durations_log": o_dr_log.squeeze(1),
"durations": o_dr.squeeze(1),
"attn_durations": o_attn, # for visualization
"pitch": o_pitch,
"pitch_gt": avg_pitch,
"pitch_avg": o_pitch,
"pitch_avg_gt": avg_pitch,
"alignments": attn,
"alignment_soft": alignment_soft.transpose(1, 2),
"alignment_mas": alignment_mas.transpose(1, 2),
"o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob,
"x_mask": x_mask,
"y_mask": y_mask,
}
return outputs
@ -577,6 +580,7 @@ class FastPitch(BaseTTS):
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
# forward pass
outputs = self.forward(
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
@ -593,12 +597,12 @@ class FastPitch(BaseTTS):
decoder_output_lens=mel_lengths,
dur_output=outputs["durations_log"],
dur_target=durations,
pitch_output=outputs["pitch"],
pitch_target=outputs["pitch_gt"],
pitch_output=outputs["pitch_avg"],
pitch_target=outputs["pitch_avg_gt"],
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"],
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
)
# compute duration error
durations_pred = outputs["durations"]
@ -611,15 +615,26 @@ class FastPitch(BaseTTS):
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
}
# plot the attention mask computed from the predicted durations