mirror of https://github.com/coqui-ai/TTS.git
Update forwardtts
This commit is contained in:
parent
bb37462794
commit
34c4be5e49
|
@ -740,6 +740,7 @@ class ForwardTTSLoss(nn.Module):
|
|||
alignment_logprob=None,
|
||||
alignment_hard=None,
|
||||
alignment_soft=None,
|
||||
binary_loss_weight=None
|
||||
):
|
||||
loss = 0
|
||||
return_dict = {}
|
||||
|
@ -772,7 +773,10 @@ class ForwardTTSLoss(nn.Module):
|
|||
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
|
||||
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
||||
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
if binary_loss_weight:
|
||||
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
|
||||
else:
|
||||
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
|
||||
return_dict["loss"] = loss
|
||||
return return_dict
|
||||
|
|
|
@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
|
|||
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -186,7 +186,7 @@ class ForwardTTS(BaseTTS):
|
|||
self.max_duration = self.args.max_duration
|
||||
self.use_aligner = self.args.use_aligner
|
||||
self.use_pitch = self.args.use_pitch
|
||||
self.use_binary_alignment_loss = False
|
||||
self.binary_loss_weight = 0.0
|
||||
|
||||
self.length_scale = (
|
||||
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
|
||||
|
@ -644,8 +644,9 @@ class ForwardTTS(BaseTTS):
|
|||
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
|
||||
input_lens=text_lengths,
|
||||
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
|
||||
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
|
||||
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
|
||||
alignment_soft=outputs["alignment_soft"],
|
||||
alignment_hard=outputs["alignment_mas"],
|
||||
binary_loss_weight=self.binary_loss_weight
|
||||
)
|
||||
# compute duration error
|
||||
durations_pred = outputs["durations"]
|
||||
|
@ -672,17 +673,12 @@ class ForwardTTS(BaseTTS):
|
|||
|
||||
# plot pitch figures
|
||||
if self.args.use_pitch:
|
||||
pitch = batch["pitch"]
|
||||
pitch_avg_expanded, _ = self.expand_encoder_outputs(
|
||||
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
|
||||
)
|
||||
pitch = pitch[0, 0].data.cpu().numpy()
|
||||
# TODO: denormalize before plotting
|
||||
pitch = abs(pitch)
|
||||
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
|
||||
pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
|
||||
pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
|
||||
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
|
||||
pitch_figures = {
|
||||
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
|
||||
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
|
||||
"pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
|
||||
"pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
|
||||
}
|
||||
figures.update(pitch_figures)
|
||||
|
||||
|
@ -725,9 +721,8 @@ class ForwardTTS(BaseTTS):
|
|||
return ForwardTTSLoss(self.config)
|
||||
|
||||
def on_train_step_start(self, trainer):
|
||||
"""Enable binary alignment loss when needed"""
|
||||
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
|
||||
self.use_binary_alignment_loss = True
|
||||
"""Schedule binary loss weight."""
|
||||
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
|
|
Loading…
Reference in New Issue