mirror of https://github.com/coqui-ai/TTS.git
Plot unnormalized pitch by `FastPitch`
This commit is contained in:
parent
2b59da802c
commit
8d41060d36
|
@ -14,9 +14,8 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.data import sequence_mask
|
from TTS.tts.utils.data import sequence_mask
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.soft_dtw import SoftDTW
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -232,7 +231,9 @@ class FastPitch(BaseTTS):
|
||||||
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
||||||
|
|
||||||
if self.args.use_aligner:
|
if self.args.use_aligner:
|
||||||
self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels)
|
self.aligner = AlignmentNetwork(
|
||||||
|
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_attn(dr, x_mask, y_mask=None):
|
def generate_attn(dr, x_mask, y_mask=None):
|
||||||
|
@ -522,13 +523,15 @@ class FastPitch(BaseTTS):
|
||||||
"durations_log": o_dr_log.squeeze(1),
|
"durations_log": o_dr_log.squeeze(1),
|
||||||
"durations": o_dr.squeeze(1),
|
"durations": o_dr.squeeze(1),
|
||||||
"attn_durations": o_attn, # for visualization
|
"attn_durations": o_attn, # for visualization
|
||||||
"pitch": o_pitch,
|
"pitch_avg": o_pitch,
|
||||||
"pitch_gt": avg_pitch,
|
"pitch_avg_gt": avg_pitch,
|
||||||
"alignments": attn,
|
"alignments": attn,
|
||||||
"alignment_soft": alignment_soft.transpose(1, 2),
|
"alignment_soft": alignment_soft.transpose(1, 2),
|
||||||
"alignment_mas": alignment_mas.transpose(1, 2),
|
"alignment_mas": alignment_mas.transpose(1, 2),
|
||||||
"o_alignment_dur": o_alignment_dur,
|
"o_alignment_dur": o_alignment_dur,
|
||||||
"alignment_logprob": alignment_logprob,
|
"alignment_logprob": alignment_logprob,
|
||||||
|
"x_mask": x_mask,
|
||||||
|
"y_mask": y_mask,
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -577,6 +580,7 @@ class FastPitch(BaseTTS):
|
||||||
speaker_ids = batch["speaker_ids"]
|
speaker_ids = batch["speaker_ids"]
|
||||||
durations = batch["durations"]
|
durations = batch["durations"]
|
||||||
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = self.forward(
|
outputs = self.forward(
|
||||||
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
|
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
|
||||||
|
@ -593,12 +597,12 @@ class FastPitch(BaseTTS):
|
||||||
decoder_output_lens=mel_lengths,
|
decoder_output_lens=mel_lengths,
|
||||||
dur_output=outputs["durations_log"],
|
dur_output=outputs["durations_log"],
|
||||||
dur_target=durations,
|
dur_target=durations,
|
||||||
pitch_output=outputs["pitch"],
|
pitch_output=outputs["pitch_avg"],
|
||||||
pitch_target=outputs["pitch_gt"],
|
pitch_target=outputs["pitch_avg_gt"],
|
||||||
input_lens=text_lengths,
|
input_lens=text_lengths,
|
||||||
alignment_logprob=outputs["alignment_logprob"],
|
alignment_logprob=outputs["alignment_logprob"],
|
||||||
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
|
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
|
||||||
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None
|
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
|
||||||
)
|
)
|
||||||
# compute duration error
|
# compute duration error
|
||||||
durations_pred = outputs["durations"]
|
durations_pred = outputs["durations"]
|
||||||
|
@ -611,15 +615,26 @@ class FastPitch(BaseTTS):
|
||||||
model_outputs = outputs["model_outputs"]
|
model_outputs = outputs["model_outputs"]
|
||||||
alignments = outputs["alignments"]
|
alignments = outputs["alignments"]
|
||||||
mel_input = batch["mel_input"]
|
mel_input = batch["mel_input"]
|
||||||
|
pitch = batch["pitch"]
|
||||||
|
pitch_avg_expanded, _ = self.expand_encoder_outputs(
|
||||||
|
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
|
||||||
|
)
|
||||||
|
|
||||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||||
gt_spec = mel_input[0].data.cpu().numpy()
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
pitch = pitch[0, 0].data.cpu().numpy()
|
||||||
|
|
||||||
|
# TODO: denormalize before plotting
|
||||||
|
pitch = abs(pitch)
|
||||||
|
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
|
||||||
|
|
||||||
figures = {
|
figures = {
|
||||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||||
"alignment": plot_alignment(align_img, output_fig=False),
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
|
||||||
|
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
|
||||||
}
|
}
|
||||||
|
|
||||||
# plot the attention mask computed from the predicted durations
|
# plot the attention mask computed from the predicted durations
|
||||||
|
|
Loading…
Reference in New Issue