mirror of https://github.com/coqui-ai/TTS.git
Plot unnormalized pitch by `FastPitch`
This commit is contained in:
parent
2b59da802c
commit
8d41060d36
|
@ -14,9 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
|||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.soft_dtw import SoftDTW
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -232,7 +231,9 @@ class FastPitch(BaseTTS):
|
|||
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
||||
|
||||
if self.args.use_aligner:
|
||||
self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels)
|
||||
self.aligner = AlignmentNetwork(
|
||||
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_attn(dr, x_mask, y_mask=None):
|
||||
|
@ -307,7 +308,7 @@ class FastPitch(BaseTTS):
|
|||
return x + g
|
||||
|
||||
def _forward_encoder(
|
||||
self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None
|
||||
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Encoding forward pass.
|
||||
|
||||
|
@ -522,13 +523,15 @@ class FastPitch(BaseTTS):
|
|||
"durations_log": o_dr_log.squeeze(1),
|
||||
"durations": o_dr.squeeze(1),
|
||||
"attn_durations": o_attn, # for visualization
|
||||
"pitch": o_pitch,
|
||||
"pitch_gt": avg_pitch,
|
||||
"pitch_avg": o_pitch,
|
||||
"pitch_avg_gt": avg_pitch,
|
||||
"alignments": attn,
|
||||
"alignment_soft": alignment_soft.transpose(1, 2),
|
||||
"alignment_mas": alignment_mas.transpose(1, 2),
|
||||
"o_alignment_dur": o_alignment_dur,
|
||||
"alignment_logprob": alignment_logprob,
|
||||
"x_mask": x_mask,
|
||||
"y_mask": y_mask,
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -577,6 +580,7 @@ class FastPitch(BaseTTS):
|
|||
speaker_ids = batch["speaker_ids"]
|
||||
durations = batch["durations"]
|
||||
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
|
||||
# forward pass
|
||||
outputs = self.forward(
|
||||
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
|
||||
|
@ -593,12 +597,12 @@ class FastPitch(BaseTTS):
|
|||
decoder_output_lens=mel_lengths,
|
||||
dur_output=outputs["durations_log"],
|
||||
dur_target=durations,
|
||||
pitch_output=outputs["pitch"],
|
||||
pitch_target=outputs["pitch_gt"],
|
||||
pitch_output=outputs["pitch_avg"],
|
||||
pitch_target=outputs["pitch_avg_gt"],
|
||||
input_lens=text_lengths,
|
||||
alignment_logprob=outputs["alignment_logprob"],
|
||||
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
|
||||
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None
|
||||
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
|
||||
)
|
||||
# compute duration error
|
||||
durations_pred = outputs["durations"]
|
||||
|
@ -611,15 +615,26 @@ class FastPitch(BaseTTS):
|
|||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
pitch = batch["pitch"]
|
||||
pitch_avg_expanded, _ = self.expand_encoder_outputs(
|
||||
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
|
||||
)
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
pitch = pitch[0, 0].data.cpu().numpy()
|
||||
|
||||
# TODO: denormalize before plotting
|
||||
pitch = abs(pitch)
|
||||
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||
"alignment": plot_alignment(align_img, output_fig=False),
|
||||
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
|
||||
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
|
||||
}
|
||||
|
||||
# plot the attention mask computed from the predicted durations
|
||||
|
|
Loading…
Reference in New Issue