mirror of https://github.com/coqui-ai/TTS.git
Update forwardtts
Commit 34c4be5e49 (parent bb37462794)
@@ -740,6 +740,7 @@ class ForwardTTSLoss(nn.Module):
         alignment_logprob=None,
         alignment_hard=None,
         alignment_soft=None,
+        binary_loss_weight=None,
     ):
         loss = 0
         return_dict = {}

@@ -772,7 +773,10 @@ class ForwardTTSLoss(nn.Module):
         if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
             binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
-            return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
+            if binary_loss_weight:
+                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+            else:
+                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss

         return_dict["loss"] = loss
         return return_dict
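For context, the _binary_alignment_loss helper called above turns the hard (MAS) and soft alignment matrices into a negative log-likelihood term. A minimal sketch of such a helper, written as a standalone function; the repo's own method may differ in detail:

import torch

def binary_alignment_loss(alignment_hard, alignment_soft):
    # Take the soft-alignment probabilities at the positions selected by the
    # hard (MAS) alignment, sum their log-likelihoods, and normalize by the
    # number of selected positions. The clamp avoids log(0).
    selected = alignment_soft[alignment_hard == 1]
    return -torch.log(torch.clamp(selected, min=1e-12)).sum() / alignment_hard.sum()

As written in the hunk above, the scheduled binary_loss_weight scales only the value reported in return_dict["loss_binary_alignment"]; the term added to the total loss keeps the binary_alignment_loss_alpha weighting alone.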
@@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
+from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram


 @dataclass

@@ -186,7 +186,7 @@ class ForwardTTS(BaseTTS):
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
-        self.use_binary_alignment_loss = False
+        self.binary_loss_weight = 0.0

         self.length_scale = (
             float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale

@@ -644,8 +644,9 @@ class ForwardTTS(BaseTTS):
                 pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
                 input_lens=text_lengths,
                 alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
-                alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
-                alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
+                alignment_soft=outputs["alignment_soft"],
+                alignment_hard=outputs["alignment_mas"],
+                binary_loss_weight=self.binary_loss_weight,
             )
             # compute duration error
             durations_pred = outputs["durations"]

@@ -672,17 +673,12 @@ class ForwardTTS(BaseTTS):

         # plot pitch figures
         if self.args.use_pitch:
-            pitch = batch["pitch"]
-            pitch_avg_expanded, _ = self.expand_encoder_outputs(
-                outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
-            )
-            pitch = pitch[0, 0].data.cpu().numpy()
-            # TODO: denormalize before plotting
-            pitch = abs(pitch)
-            pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
+            pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
+            pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
+            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
             pitch_figures = {
-                "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
-                "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
+                "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
+                "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
             }
             figures.update(pitch_figures)

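This logging change swaps frame-level pitch plots drawn over the spectrogram (plot_pitch) for per-character average pitch plots drawn over the decoded input text (plot_avg_pitch). A rough, hypothetical sketch of what such a per-character plot can look like, assuming matplotlib; the real TTS.tts.utils.visual.plot_avg_pitch may differ in styling and signature details:

import matplotlib.pyplot as plt
import numpy as np

def plot_avg_pitch_sketch(pitch, chars, output_fig=False):
    # One bar per input token, labeled with its character, so averaged pitch
    # values can be read against the text rather than against spectrogram frames.
    fig, ax = plt.subplots(figsize=(max(6, len(pitch) * 0.4), 3))
    positions = np.arange(len(pitch))
    ax.bar(positions, pitch)
    ax.set_xticks(positions)
    ax.set_xticklabels(list(chars)[: len(pitch)])
    ax.set_xlabel("input characters")
    ax.set_ylabel("average pitch")
    fig.tight_layout()
    if not output_fig:
        plt.close(fig)
    return fig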
@@ -725,9 +721,8 @@ class ForwardTTS(BaseTTS):
         return ForwardTTSLoss(self.config)

     def on_train_step_start(self, trainer):
-        """Enable binary alignment loss when needed"""
-        if trainer.total_steps_done > self.config.binary_align_loss_start_step:
-            self.use_binary_alignment_loss = True
+        """Schedule binary loss weight."""
+        self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0

     @staticmethod
     def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
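Where the old hook flipped the binary alignment loss on once trainer.total_steps_done passed binary_align_loss_start_step, the new hook ramps binary_loss_weight linearly with completed epochs and then saturates at 1.0. A quick illustration of the schedule, assuming a hypothetical binary_loss_warmup_epochs of 10:

# Hypothetical config value, for illustration only.
binary_loss_warmup_epochs = 10

for epochs_done in (0, 1, 5, 10, 25):
    # Same expression as in on_train_step_start above.
    weight = min(epochs_done / binary_loss_warmup_epochs, 1.0) * 1.0
    print(epochs_done, weight)  # 0 -> 0.0, 1 -> 0.1, 5 -> 0.5, 10 -> 1.0, 25 -> 1.0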