Refactor multi-speaker init in BaseTTS-Tacotron1-2

Eren Gölge 2021-10-18 08:55:45 +00:00
parent 127571423c
commit c514351c0e
4 changed files with 94 additions and 59 deletions

View File: TTS/tts/models/tacotron.py

@@ -3,11 +3,13 @@
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -15,11 +17,17 @@ class Tacotron(BaseTacotron):
"""Tacotron as in https://arxiv.org/abs/1703.10135
It's an autoregressive encoder-attention-decoder-postnet architecture.
Check `TacotronConfig` for the arguments.
Args:
config (TacotronConfig): Configuration for the Tacotron model.
speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Use it only with
multi-speaker models. Defaults to None.
"""
def __init__(self, config: Coqpit):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager
chars, self.config, _ = self.get_characters(config)
config.num_chars = self.num_chars = len(chars)
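For orientation, a minimal sketch of how the refactored constructor might be called. The config values and the way the `SpeakerManager` is populated are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage of the refactored constructor (illustrative only).
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.models.tacotron import Tacotron
from TTS.tts.utils.speakers import SpeakerManager

config = TacotronConfig(use_speaker_embedding=True)  # assumes a complete config
speaker_manager = SpeakerManager()  # assumed to be filled from dataset samples
model = Tacotron(config, speaker_manager=speaker_manager)
```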
@@ -240,21 +248,22 @@
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
# compute loss
loss_dict = criterion(
outputs["model_outputs"],
outputs["decoder_outputs"],
mel_input,
linear_input,
outputs["stop_tokens"],
stop_targets,
stop_target_lengths,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],
alignment_lengths,
outputs["alignments_backward"],
text_lengths,
)
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion(
outputs["model_outputs"].float(),
outputs["decoder_outputs"].float(),
mel_input.float(),
linear_input.float(),
outputs["stop_tokens"].float(),
stop_targets.float(),
stop_target_lengths,
mel_lengths,
outputs["decoder_outputs_backward"].float(),
outputs["alignments"].float(),
alignment_lengths,
outputs["alignments_backward"].float(),
text_lengths,
)
# compute alignment error (the lower the better)
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
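The new `autocast(enabled=False)` wrapper makes the criterion run in float32 even when the surrounding training step uses mixed precision, avoiding fp16 underflow/overflow in the loss. A minimal, self-contained sketch of the same pattern (the loss function here is a stand-in, not the project's criterion):

```python
import torch
from torch.cuda.amp import autocast

criterion = torch.nn.L1Loss()  # stand-in for the model's criterion

def fp32_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Disable autocast locally and upcast the inputs, so the loss math runs
    # in float32 even when the surrounding forward pass uses mixed precision.
    with autocast(enabled=False):
        return criterion(pred.float(), target.float())
```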
@@ -263,17 +272,23 @@ class Tacotron(BaseTacotron):
def _create_logs(self, batch, outputs, ap):
postnet_outputs = outputs["model_outputs"]
decoder_outputs = outputs["decoder_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]
mel_input = batch["mel_input"]
linear_input = batch["linear_input"]
pred_spec = postnet_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
pred_linear_spec = postnet_outputs[0].data.cpu().numpy()
pred_mel_spec = decoder_outputs[0].data.cpu().numpy()
gt_linear_spec = linear_input[0].data.cpu().numpy()
gt_mel_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False),
"real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False),
"pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False),
"real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
}
@@ -281,7 +296,7 @@ class Tacotron(BaseTacotron):
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
# Sample audio
audio = ap.inv_spectrogram(pred_spec.T)
audio = ap.inv_spectrogram(pred_linear_spec.T)
return figures, {"audio": audio}
def train_log(

View File: TTS/tts/models/tacotron2.py

@@ -3,22 +3,45 @@
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
class Tacotron2(BaseTacotron):
"""Tacotron2 as in https://arxiv.org/abs/1712.05884
Check `TacotronConfig` for the arguments.
"""Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`.
Paper::
https://arxiv.org/abs/1712.05884
Paper abstract::
This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text.
The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character
embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize
time-domain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable
to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation
studies of key components of our system and evaluate the impact of using mel spectrograms as the input to
WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic
intermediate representation enables significant simplification of the WaveNet architecture.
Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments.
Args:
config (TacotronConfig):
Configuration for the Tacotron2 model.
speaker_manager (SpeakerManager):
Speaker manager for multi-speaker training. Use it only with multi-speaker models. Defaults to None.
"""
def __init__(self, config: Coqpit):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager
chars, self.config, _ = self.get_characters(config)
config.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
@@ -28,9 +51,7 @@ class Tacotron2(BaseTacotron):
for key in config:
setattr(self, key, config[key])
# set speaker embedding channel size for determining `in_channels` for the connected layers.
# `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
# on the number of speakers inferred from the dataset.
# init multi-speaker layers
if self.use_speaker_embedding or self.use_d_vector_file:
self.init_multispeaker(config)
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
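The `decoder_in_features` bump exists because the per-utterance speaker embedding is concatenated with the encoder outputs along the channel axis. A small sketch of that concatenation; all dimensions here are illustrative, not the model's actual sizes:

```python
import torch

B, T, enc_dim, spk_dim = 2, 50, 512, 128  # illustrative sizes
encoder_outputs = torch.randn(B, T, enc_dim)
speaker_embedding = torch.randn(B, spk_dim)

# Repeat the per-utterance embedding over the time axis, then concatenate.
speaker_embedding = speaker_embedding[:, None, :].expand(B, T, spk_dim)
decoder_inputs = torch.cat([encoder_outputs, speaker_embedding], dim=-1)
assert decoder_inputs.shape == (B, T, enc_dim + spk_dim)  # decoder_in_features
```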
@@ -100,6 +121,7 @@ class Tacotron2(BaseTacotron):
@staticmethod
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
"""Final reshape of the model output tensors."""
mel_outputs = mel_outputs.transpose(1, 2)
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments
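`shape_outputs` only swaps the channel and time axes: the decoder produces `[B, C, T]` tensors while the losses and plotting utilities expect `[B, T, C]`. For example:

```python
import torch

mel_outputs = torch.randn(4, 80, 120)      # [B, n_mels, T] from the decoder
mel_outputs = mel_outputs.transpose(1, 2)  # -> [B, T, n_mels]
assert mel_outputs.shape == (4, 120, 80)
```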
@@ -107,7 +129,8 @@ class Tacotron2(BaseTacotron):
def forward( # pylint: disable=dangerous-default-value
self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
):
"""
"""Forward pass for training with Teacher Forcing.
Shapes:
text: [B, T_in]
text_lengths: [B]
@@ -174,6 +197,12 @@
@torch.no_grad()
def inference(self, text, aux_input=None):
"""Forward pass for inference with no Teacher-Forcing.
Shapes:
text: :math:`[B, T_in]`
text_lengths: :math:`[B]`
"""
aux_input = self._format_aux_input(aux_input)
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
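A hedged inference sketch matching the shapes documented above; `model` is assumed to be a constructed `Tacotron2`, and the dummy character ids stand in for a real tokenization pipeline:

```python
import torch

# Dummy batch of character ids, shape [B, T_in]; a real pipeline would
# tokenize text with the model's character set.
text = torch.randint(0, 60, (1, 30))
outputs = model.inference(text, aux_input={"speaker_ids": None, "d_vectors": None})
mel_spec = outputs["model_outputs"]  # postnet output, [B, T_out, n_mels]
```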
@@ -208,7 +237,7 @@ class Tacotron2(BaseTacotron):
return outputs
def train_step(self, batch, criterion):
"""Perform a single training step by fetching the right set if samples from the batch.
"""A single training step. Forward pass and loss computation.
Args:
batch ([type]): [description]
@@ -218,7 +247,6 @@
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
stop_target_lengths = batch["stop_target_lengths"]
speaker_ids = batch["speaker_ids"]
@@ -245,21 +273,22 @@
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
# compute loss
loss_dict = criterion(
outputs["model_outputs"],
outputs["decoder_outputs"],
mel_input,
linear_input,
outputs["stop_tokens"],
stop_targets,
stop_target_lengths,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],
alignment_lengths,
outputs["alignments_backward"],
text_lengths,
)
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion(
outputs["model_outputs"].float(),
outputs["decoder_outputs"].float(),
mel_input.float(),
None,
outputs["stop_tokens"].float(),
stop_targets.float(),
stop_target_lengths,
mel_lengths,
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
outputs["alignments"].float(),
alignment_lengths,
None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
text_lengths,
)
# compute alignment error (the lower the better)
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
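Unlike Tacotron 1, Tacotron 2 has no linear-spectrogram target, so `None` is passed in that slot, and the backward-decoder tensors are cast only when they exist. The repeated guard condenses to a small helper like this (hypothetical, not part of the diff):

```python
from typing import Optional
import torch

def float_or_none(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Upcast to float32 only when the tensor is present (hypothetical helper)."""
    return None if x is None else x.float()
```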

View File: TTS/tts/models/vits.py

@@ -217,12 +217,13 @@ class Vits(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.END2END = True
self.speaker_manager = speaker_manager
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
@@ -314,7 +315,7 @@
if args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
def init_multispeaker(self, config: Coqpit, data: List = None):
def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
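The two modes this docstring describes might be set up as below; the file paths are placeholders and the exact `SpeakerManager` keyword arguments are assumptions based on the call shown in this hunk:

```python
from TTS.tts.utils.speakers import SpeakerManager

# a) Learnable speaker-embedding layer, sized from a known speaker list.
speaker_manager = SpeakerManager(speaker_id_file_path="speakers.json")

# b) External d-vectors precomputed by a speaker encoder model.
speaker_manager = SpeakerManager(d_vectors_file_path="d_vectors.json")
```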
@@ -351,18 +352,6 @@
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
self.embedded_speaker_dim = config.d_vector_dim
def on_init_start(self, trainer):
"""Save the speaker.json at the beginning of the training. And update the config.json with the
speakers.json file path."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
trainer.config.speakers_file = output_path
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.json` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""

View File: TTS/utils/audio.py

@@ -108,6 +108,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
class AudioProcessor(object):
"""Audio Processor for TTS used by all the data pipelines.
TODO: Make this a dataclass to replace `BaseAudioConfig`.
Note:
All the class arguments are set to default values to enable flexible initialization
from the model config; the defaults are not meaningful for every argument.
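In practice the processor is usually built straight from the model config's `audio` section, which is why defaulted-but-unused arguments are tolerated. A hedged sketch, assuming `config` carries a `BaseAudioConfig` under `config.audio`:

```python
from TTS.utils.audio import AudioProcessor

# Any argument missing from the config keeps its default value.
ap = AudioProcessor(**config.audio.to_dict())
```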