Refactor multi-speaker init in BaseTTS-Tacotron1-2

Eren Gölge 2021-10-18 08:55:45 +00:00
parent 127571423c
commit c514351c0e
4 changed files with 94 additions and 59 deletions

TTS/tts/models/tacotron.py

@@ -3,11 +3,13 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
+from torch.cuda.amp.autocast_mode import autocast

 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
 from TTS.tts.models.base_tacotron import BaseTacotron
 from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -15,11 +17,17 @@ class Tacotron(BaseTacotron):
     """Tacotron as in https://arxiv.org/abs/1703.10135
     It's an autoregressive encoder-attention-decoder-postnet architecture.
     Check `TacotronConfig` for the arguments.
+
+    Args:
+        config (TacotronConfig): Configuration for the Tacotron model.
+        speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Only use if the model is
+            a multi-speaker model. Defaults to None.
     """

-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
         super().__init__(config)
+        self.speaker_manager = speaker_manager
         chars, self.config, _ = self.get_characters(config)
         config.num_chars = self.num_chars = len(chars)
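The constructor change is the heart of the refactor: the SpeakerManager is now injected from outside rather than built inside the model. A minimal sketch of the new call pattern (the config values and manager setup here are illustrative, not taken from this commit):

    from TTS.tts.configs.tacotron_config import TacotronConfig
    from TTS.tts.models.tacotron import Tacotron
    from TTS.tts.utils.speakers import SpeakerManager

    config = TacotronConfig(use_speaker_embedding=True)  # illustrative config
    speaker_manager = SpeakerManager()  # populate from your dataset or a speakers.json
    model = Tacotron(config, speaker_manager=speaker_manager)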
@@ -240,21 +248,22 @@ class Tacotron(BaseTacotron):
         outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)

         # compute loss
-        loss_dict = criterion(
-            outputs["model_outputs"],
-            outputs["decoder_outputs"],
-            mel_input,
-            linear_input,
-            outputs["stop_tokens"],
-            stop_targets,
-            stop_target_lengths,
-            mel_lengths,
-            outputs["decoder_outputs_backward"],
-            outputs["alignments"],
-            alignment_lengths,
-            outputs["alignments_backward"],
-            text_lengths,
-        )
+        with autocast(enabled=False):  # use float32 for the criterion
+            loss_dict = criterion(
+                outputs["model_outputs"].float(),
+                outputs["decoder_outputs"].float(),
+                mel_input.float(),
+                linear_input.float(),
+                outputs["stop_tokens"].float(),
+                stop_targets.float(),
+                stop_target_lengths,
+                mel_lengths,
+                outputs["decoder_outputs_backward"].float(),
+                outputs["alignments"].float(),
+                alignment_lengths,
+                outputs["alignments_backward"].float(),
+                text_lengths,
+            )

         # compute alignment error (the lower the better )
         align_error = 1 - alignment_diagonal_score(outputs["alignments"])
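The new `autocast(enabled=False)` guard targets mixed-precision (AMP) training: the forward pass may run in float16, where losses such as the stop-token BCE and the masked spectrogram terms can overflow or lose precision, so every criterion input is cast back to float32. A generic sketch of the pattern, not code from this repository:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def amp_train_step(model, criterion, optimizer, scaler: GradScaler, inputs, targets):
        """One AMP training step; the loss itself is computed in float32."""
        optimizer.zero_grad()
        with autocast():  # forward pass may run in float16 on CUDA
            outputs = model(inputs)
        with autocast(enabled=False):  # criterion sees full-precision tensors
            loss = criterion(outputs.float(), targets.float())
        scaler.scale(loss).backward()  # scale to avoid float16 gradient underflow
        scaler.step(optimizer)
        scaler.update()
        return loss.item()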
@@ -263,17 +272,23 @@ class Tacotron(BaseTacotron):
     def _create_logs(self, batch, outputs, ap):
         postnet_outputs = outputs["model_outputs"]
+        decoder_outputs = outputs["decoder_outputs"]
         alignments = outputs["alignments"]
         alignments_backward = outputs["alignments_backward"]
         mel_input = batch["mel_input"]
+        linear_input = batch["linear_input"]

-        pred_spec = postnet_outputs[0].data.cpu().numpy()
-        gt_spec = mel_input[0].data.cpu().numpy()
+        pred_linear_spec = postnet_outputs[0].data.cpu().numpy()
+        pred_mel_spec = decoder_outputs[0].data.cpu().numpy()
+        gt_linear_spec = linear_input[0].data.cpu().numpy()
+        gt_mel_spec = mel_input[0].data.cpu().numpy()
         align_img = alignments[0].data.cpu().numpy()

         figures = {
-            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
-            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False),
+            "real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False),
+            "pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False),
+            "real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False),
             "alignment": plot_alignment(align_img, output_fig=False),
         }
@@ -281,7 +296,7 @@ class Tacotron(BaseTacotron):
             figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)

         # Sample audio
-        audio = ap.inv_spectrogram(pred_spec.T)
+        audio = ap.inv_spectrogram(pred_linear_spec.T)
         return figures, {"audio": audio}

     def train_log(

TTS/tts/models/tacotron2.py

@@ -3,22 +3,45 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
+from torch.cuda.amp.autocast_mode import autocast

 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
 from TTS.tts.models.base_tacotron import BaseTacotron
 from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram


 class Tacotron2(BaseTacotron):
-    """Tacotron2 as in https://arxiv.org/abs/1712.05884
-    Check `TacotronConfig` for the arguments.
+    """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`.
+
+    Paper::
+        https://arxiv.org/abs/1712.05884
+
+    Paper abstract::
+        This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text.
+        The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character
+        embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize
+        time-domain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable
+        to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation
+        studies of key components of our system and evaluate the impact of using mel spectrograms as the input to
+        WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic
+        intermediate representation enables significant simplification of the WaveNet architecture.
+
+    Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments.
+
+    Args:
+        config (TacotronConfig):
+            Configuration for the Tacotron2 model.
+        speaker_manager (SpeakerManager):
+            Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
     """

-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
         super().__init__(config)
+        self.speaker_manager = speaker_manager
         chars, self.config, _ = self.get_characters(config)
         config.num_chars = len(chars)
         self.decoder_output_dim = config.out_channels
@@ -28,9 +51,7 @@ class Tacotron2(BaseTacotron):
         for key in config:
             setattr(self, key, config[key])

-        # set speaker embedding channel size for determining `in_channels` for the connected layers.
-        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
-        # on the number of speakers infered from the dataset.
+        # init multi-speaker layers
         if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim
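The `decoder_in_features += self.embedded_speaker_dim` line reflects how the multi-speaker conditioning works: the speaker embedding is repeated along the time axis and concatenated onto the encoder outputs, so the decoder input width grows by the embedding size. A shape-only illustration (all sizes made up, not the model's internals):

    import torch

    B, T, enc_dim, spk_dim = 2, 50, 512, 64  # illustrative sizes
    encoder_outputs = torch.randn(B, T, enc_dim)
    speaker_emb = torch.randn(B, 1, spk_dim).expand(-1, T, -1)  # repeat over time
    decoder_inputs = torch.cat([encoder_outputs, speaker_emb], dim=-1)
    assert decoder_inputs.shape == (B, T, enc_dim + spk_dim)  # decoder_in_features grew by spk_dim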
@@ -100,6 +121,7 @@ class Tacotron2(BaseTacotron):
     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
+        """Final reshape of the model output tensors."""
         mel_outputs = mel_outputs.transpose(1, 2)
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments
@@ -107,7 +129,8 @@ class Tacotron2(BaseTacotron):
     def forward(  # pylint: disable=dangerous-default-value
         self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
     ):
-        """
+        """Forward pass for training with Teacher Forcing.
+
         Shapes:
             text: [B, T_in]
             text_lengths: [B]
@@ -174,6 +197,12 @@ class Tacotron2(BaseTacotron):
     @torch.no_grad()
     def inference(self, text, aux_input=None):
+        """Forward pass for inference with no Teacher-Forcing.
+
+        Shapes:
+            text: :math:`[B, T_in]`
+            text_lengths: :math:`[B]`
+        """
         aux_input = self._format_aux_input(aux_input)
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
@@ -208,7 +237,7 @@ class Tacotron2(BaseTacotron):
         return outputs

     def train_step(self, batch, criterion):
-        """Perform a single training step by fetching the right set if samples from the batch.
+        """A single training step. Forward pass and loss computation.

         Args:
             batch ([type]): [description]
@@ -218,7 +247,6 @@ class Tacotron2(BaseTacotron):
         text_lengths = batch["text_lengths"]
         mel_input = batch["mel_input"]
         mel_lengths = batch["mel_lengths"]
-        linear_input = batch["linear_input"]
         stop_targets = batch["stop_targets"]
         stop_target_lengths = batch["stop_target_lengths"]
         speaker_ids = batch["speaker_ids"]
@@ -245,21 +273,22 @@ class Tacotron2(BaseTacotron):
         outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)

         # compute loss
-        loss_dict = criterion(
-            outputs["model_outputs"],
-            outputs["decoder_outputs"],
-            mel_input,
-            linear_input,
-            outputs["stop_tokens"],
-            stop_targets,
-            stop_target_lengths,
-            mel_lengths,
-            outputs["decoder_outputs_backward"],
-            outputs["alignments"],
-            alignment_lengths,
-            outputs["alignments_backward"],
-            text_lengths,
-        )
+        with autocast(enabled=False):  # use float32 for the criterion
+            loss_dict = criterion(
+                outputs["model_outputs"].float(),
+                outputs["decoder_outputs"].float(),
+                mel_input.float(),
+                None,
+                outputs["stop_tokens"].float(),
+                stop_targets.float(),
+                stop_target_lengths,
+                mel_lengths,
+                None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
+                outputs["alignments"].float(),
+                alignment_lengths,
+                None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
+                text_lengths,
+            )

         # compute alignment error (the lower the better )
         align_error = 1 - alignment_diagonal_score(outputs["alignments"])
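Unlike the Tacotron version above, this criterion call passes `None` for the linear spectrogram (Tacotron 2 predicts mels only) and guards the backward-decoder tensors, which stay `None` unless double decoder consistency is enabled. The repeated guard is the usual optional-cast idiom; a hypothetical helper (not in the commit) that would tidy it:

    def float_or_none(t):
        """Cast a tensor to float32; pass None through unchanged."""
        return None if t is None else t.float()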

TTS/tts/models/vits.py

@@ -217,12 +217,13 @@ class Vits(BaseTTS):
     # pylint: disable=dangerous-default-value
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
         super().__init__(config)

         self.END2END = True
+        self.speaker_manager = speaker_manager

         if config.__class__.__name__ == "VitsConfig":
             # loading from VitsConfig
             if "num_chars" not in config:
@@ -314,7 +315,7 @@ class Vits(BaseTTS):
         if args.init_discriminator:
             self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)

-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
         or with external `d_vectors` computed from a speaker encoder model.
@@ -351,18 +352,6 @@ class Vits(BaseTTS):
             self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim

-    def on_init_start(self, trainer):
-        """Save the speaker.json at the beginning of the training. And update the config.json with the
-        speakers.json file path."""
-        if self.speaker_manager is not None:
-            output_path = os.path.join(trainer.output_path, "speakers.json")
-            self.speaker_manager.save_speaker_ids_to_file(output_path)
-            trainer.config.speakers_file = output_path
-            trainer.config.model_args.speakers_file = output_path
-            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
-            print(f" > `speakers.json` is saved to {output_path}.")
-            print(" > `speakers_file` is updated in the config.json.")
-
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""

TTS/utils/audio.py

@@ -108,6 +108,8 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
 class AudioProcessor(object):
     """Audio Processor for TTS used by all the data pipelines.

+    TODO: Make this a dataclass to replace `BaseAudioConfig`.
+
     Note:
         All the class arguments are set to default values to enable a flexible initialization
         of the class with the model config. They are not meaningful for all the arguments.
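The new TODO suggests collapsing AudioProcessor's long keyword-argument `__init__` into a dataclass-style config. A rough sketch of the direction (field names mirror a few real AudioProcessor arguments; the class name is hypothetical):

    from dataclasses import dataclass

    @dataclass
    class AudioProcessorConfig:  # hypothetical stand-in for `BaseAudioConfig`
        sample_rate: int = 22050
        num_mels: int = 80
        fft_size: int = 1024
        hop_length: int = 256
        win_length: int = 1024

    # Old call sites could then unpack the dataclass:
    # AudioProcessor(**vars(AudioProcessorConfig()))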