Update forwardtts

This commit is contained in:
Eren Gölge 2022-01-25 09:28:33 +00:00
parent bb37462794
commit 34c4be5e49
2 changed files with 17 additions and 18 deletions

View File

@ -740,6 +740,7 @@ class ForwardTTSLoss(nn.Module):
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
binary_loss_weight=None
):
loss = 0
return_dict = {}
@ -772,7 +773,10 @@ class ForwardTTSLoss(nn.Module):
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
if binary_loss_weight:
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
else:
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss"] = loss
return return_dict

View File

@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
@dataclass
@ -186,7 +186,7 @@ class ForwardTTS(BaseTTS):
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_pitch = self.args.use_pitch
self.use_binary_alignment_loss = False
self.binary_loss_weight = 0.0
self.length_scale = (
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
@ -644,8 +644,9 @@ class ForwardTTS(BaseTTS):
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
alignment_soft=outputs["alignment_soft"],
alignment_hard=outputs["alignment_mas"],
binary_loss_weight=self.binary_loss_weight
)
# compute duration error
durations_pred = outputs["durations"]
@ -672,17 +673,12 @@ class ForwardTTS(BaseTTS):
# plot pitch figures
if self.args.use_pitch:
pitch = batch["pitch"]
pitch_avg_expanded, _ = self.expand_encoder_outputs(
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
)
pitch = pitch[0, 0].data.cpu().numpy()
# TODO: denormalize before plotting
pitch = abs(pitch)
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
pitch_figures = {
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
"pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
"pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
}
figures.update(pitch_figures)
@ -725,9 +721,8 @@ class ForwardTTS(BaseTTS):
return ForwardTTSLoss(self.config)
def on_train_step_start(self, trainer):
"""Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True
"""Schedule binary loss weight."""
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
@staticmethod
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):