coqui-tts/TTS/tts/models/tacotron2.py

341 lines
14 KiB
Python

# coding: utf-8
from typing import Dict
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
class Tacotron2(BaseTacotron):
    """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`.

    Paper::
        https://arxiv.org/abs/1712.05884

    Paper abstract::
        This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text.
        The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character
        embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize
        time-domain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable
        to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation
        studies of key components of our system and evaluate the impact of using mel spectrograms as the input to
        WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic
        intermediate representation enables significant simplification of the WaveNet architecture.

    Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments.

    Args:
        config (TacotronConfig):
            Configuration for the Tacotron2 model.
        speaker_manager (SpeakerManager):
            Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
    """

    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
        """Build the embedding, encoder, decoder(s), postnet and optional GST / multi-speaker layers.

        Args:
            config (Coqpit): Model configuration. Every field is also copied onto ``self`` as an attribute.
            speaker_manager (SpeakerManager): Stored on ``self.speaker_manager``. Defaults to None.
        """
        super().__init__(config)

        self.speaker_manager = speaker_manager
        chars, self.config, _ = self.get_characters(config)
        config.num_chars = len(chars)
        self.decoder_output_dim = config.out_channels

        # pass all config fields to `self`
        # for fewer code change
        for key in config:
            setattr(self, key, config[key])

        # init multi-speaker layers
        if self.use_speaker_embedding or self.use_d_vector_file:
            self.init_multispeaker(config)
            self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

        if self.use_gst:
            self.decoder_in_features += self.gst.gst_embedding_dim

        # embedding layer
        self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)

        # base model layers
        self.encoder = Encoder(self.encoder_in_features)
        self.decoder = self._build_decoder(self.r)
        self.postnet = Postnet(self.out_channels)

        # setup prenet dropout
        self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference

        # global style token layers
        if self.gst and self.use_gst:
            self.gst_layer = GST(
                num_mel=self.decoder_output_dim,
                num_heads=self.gst.gst_num_heads,
                num_style_tokens=self.gst.gst_num_style_tokens,
                gst_embedding_dim=self.gst.gst_embedding_dim,
            )

        # backward pass decoder
        if self.bidirectional_decoder:
            self._init_backward_decoder()

        # setup DDC: a second decoder that differs only in its reduction factor
        if self.double_decoder_consistency:
            self.coarse_decoder = self._build_decoder(self.ddc_r)

    def _build_decoder(self, r):
        """Create a :class:`Decoder` from the config fields on ``self`` using reduction factor ``r``."""
        return Decoder(
            self.decoder_in_features,
            self.decoder_output_dim,
            r,
            self.attention_type,
            self.attention_win,
            self.attention_norm,
            self.prenet_type,
            self.prenet_dropout,
            self.use_forward_attn,
            self.transition_agent,
            self.forward_attn_mask,
            self.location_attn,
            self.attention_heads,
            self.separate_stopnet,
            self.max_decoder_steps,
        )

    @staticmethod
    def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
        """Final reshape of the model output tensors.

        Swaps dimensions 1 and 2 of both mel tensors; ``alignments`` is returned untouched.
        """
        mel_outputs = mel_outputs.transpose(1, 2)
        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
        return mel_outputs, mel_outputs_postnet, alignments

    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
        """Forward pass for training with Teacher Forcing.

        Args:
            text: Input character ids.
            text_lengths: Length of each input sequence.
            mel_specs: Target mel spectrograms fed to the decoder.
            mel_lengths: Length of each target spectrogram.
            aux_input (dict): Optional ``speaker_ids`` and ``d_vectors``. When omitted, both
                default to None.

        Returns:
            dict: ``model_outputs`` (postnet output), ``decoder_outputs``, ``alignments``,
            ``stop_tokens``, plus ``alignments_backward`` and ``decoder_outputs_backward``, which
            stay None unless the bidirectional decoder or DDC is enabled.

        Shapes:
            text: :math:`[B, T_in]`
            text_lengths: :math:`[B]`
            mel_specs: :math:`[B, T_out, C]`
            mel_lengths: :math:`[B]`
            aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]`
        """
        # Build the default per call instead of sharing one mutable dict across all calls.
        if aux_input is None:
            aux_input = {"speaker_ids": None, "d_vectors": None}
        aux_input = self._format_aux_input(aux_input)
        outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
        # compute mask for padding
        # B x T_in_max (boolean)
        input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
        # B x D_embed x T_in_max
        embedded_inputs = self.embedding(text).transpose(1, 2)
        # B x T_in_max x D_en
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
        if self.gst and self.use_gst:
            # B x gst_dim
            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)

        if self.use_speaker_embedding or self.use_d_vector_file:
            if not self.use_d_vector_file:
                # B x 1 x speaker_embed_dim
                embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
            else:
                # B x 1 x speaker_embed_dim
                embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)

        # zero out the encoder frames that belong to padding
        encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)

        # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
        decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
        # sequence masking (guard on the mask itself, same as the postnet masking below)
        if output_mask is not None:
            decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
        # B x mel_dim x T_out
        postnet_outputs = self.postnet(decoder_outputs)
        # the postnet output is added on top of the decoder output
        postnet_outputs = decoder_outputs + postnet_outputs
        # sequence masking
        if output_mask is not None:
            postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
        # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
        decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)

        if self.bidirectional_decoder:
            decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
            outputs["alignments_backward"] = alignments_backward
            outputs["decoder_outputs_backward"] = decoder_outputs_backward

        # NOTE: when both options are enabled, the DDC results overwrite the bidirectional ones.
        if self.double_decoder_consistency:
            decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
                mel_specs, encoder_outputs, alignments, input_mask
            )
            outputs["alignments_backward"] = alignments_backward
            outputs["decoder_outputs_backward"] = decoder_outputs_backward

        outputs.update(
            {
                "model_outputs": postnet_outputs,
                "decoder_outputs": decoder_outputs,
                "alignments": alignments,
                "stop_tokens": stop_tokens,
            }
        )
        return outputs

    @torch.no_grad()
    def inference(self, text, aux_input=None):
        """Forward pass for inference with no Teacher-Forcing.

        Args:
            text: Input character ids.
            aux_input (dict): Auxiliary inputs. This method reads ``style_mel`` and ``d_vectors``
                when GST is enabled, and ``speaker_ids`` or ``d_vectors`` when ``num_speakers > 1``.

        Returns:
            dict: ``model_outputs`` (postnet output), ``decoder_outputs``, ``alignments`` and
            ``stop_tokens``.

        Shapes:
            text: :math:`[B, T_in]`
            text_lengths: :math:`[B]`
        """
        aux_input = self._format_aux_input(aux_input)
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)

        if self.gst and self.use_gst:
            # B x gst_dim
            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])

        if self.num_speakers > 1:
            if not self.use_d_vector_file:
                embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
                # reshape embedded_speakers
                # NOTE(review): the `[None]` above already adds a leading dimension, so the
                # `ndim == 1` branch looks unreachable - kept unchanged, confirm before removing.
                if embedded_speakers.ndim == 1:
                    embedded_speakers = embedded_speakers[None, None, :]
                elif embedded_speakers.ndim == 2:
                    embedded_speakers = embedded_speakers[None, :]
            else:
                embedded_speakers = aux_input["d_vectors"]

            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)

        decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
        postnet_outputs = self.postnet(decoder_outputs)
        postnet_outputs = decoder_outputs + postnet_outputs
        decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
        outputs = {
            "model_outputs": postnet_outputs,
            "decoder_outputs": decoder_outputs,
            "alignments": alignments,
            "stop_tokens": stop_tokens,
        }
        return outputs

    def train_step(self, batch: Dict, criterion: torch.nn.Module):
        """A single training step. Forward pass and loss computation.

        The model is run exactly once per step; the loss is computed in float32 with autocast
        disabled.

        Args:
            batch ([Dict]): A dictionary of input tensors. Reads ``text_input``, ``text_lengths``,
                ``mel_input``, ``mel_lengths``, ``stop_targets``, ``stop_target_lengths``,
                ``speaker_ids`` and ``d_vectors``.
            criterion ([type]): Callable criterion to compute model loss.

        Returns:
            Tuple[Dict, Dict]: The model outputs and the loss dictionary, extended with
            ``align_error``.
        """
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        stop_target_lengths = batch["stop_target_lengths"]
        speaker_ids = batch["speaker_ids"]
        d_vectors = batch["d_vectors"]

        # forward pass model (only once: a repeated call would double the compute and apply
        # any stateful or stochastic layers a second time for no benefit)
        outputs = self.forward(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
        )

        # set the [alignment] lengths wrt reduction factor for guided attention
        if mel_lengths.max() % self.decoder.r != 0:
            # the longest sequence is not a multiple of r: shift every length by the amount
            # missing to the next multiple before dividing
            alignment_lengths = (
                mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
            ) // self.decoder.r
        else:
            alignment_lengths = mel_lengths // self.decoder.r

        # compute loss
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion(
                outputs["model_outputs"].float(),
                outputs["decoder_outputs"].float(),
                mel_input.float(),
                None,
                outputs["stop_tokens"].float(),
                stop_targets.float(),
                stop_target_lengths,
                mel_lengths,
                None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
                outputs["alignments"].float(),
                alignment_lengths,
                None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
                text_lengths,
            )

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(outputs["alignments"])
        loss_dict["align_error"] = align_error
        return outputs, loss_dict

    def _create_logs(self, batch, outputs, ap):
        """Create dashboard log information.

        Uses the first sample of the batch to plot the predicted and ground-truth spectrograms
        and the attention alignment(s), and to synthesize an audio sample from the prediction.

        Returns:
            Tuple[Dict, Dict]: The figures dictionary and ``{"audio": audio}``.
        """
        postnet_outputs = outputs["model_outputs"]
        alignments = outputs["alignments"]
        alignments_backward = outputs["alignments_backward"]
        mel_input = batch["mel_input"]

        pred_spec = postnet_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        if self.bidirectional_decoder or self.double_decoder_consistency:
            figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)

        # Sample audio
        audio = ap.inv_melspectrogram(pred_spec.T)
        return figures, {"audio": audio}

    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ) -> None:  # pylint: disable=no-self-use
        """Log training progress."""
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, ap.sample_rate)

    def eval_step(self, batch: dict, criterion: nn.Module):
        """Run one evaluation step; delegates to :meth:`train_step`."""
        return self.train_step(batch, criterion)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        """Log evaluation figures and audio."""
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, ap.sample_rate)