From e4648ffef11c3a606c55b738b486007998c25459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 7 Aug 2021 21:37:46 +0000 Subject: [PATCH] Fix multi-speaker init of Tacotron models & tests --- TTS/tts/layers/tacotron/gst_layers.py | 12 +- TTS/tts/models/base_tacotron.py | 6 +- TTS/tts/models/glow_tts.py | 46 ++++++- TTS/tts/models/speedy_speech.py | 3 +- TTS/tts/models/tacotron.py | 28 ++-- TTS/tts/models/tacotron2.py | 29 ++-- TTS/tts/utils/speakers.py | 9 +- .../test_tacotron2_d-vectors_train.py | 1 + tests/tts_tests/test_tacotron2_model.py | 130 ++++++++++-------- tests/tts_tests/test_tacotron_model.py | 118 +++++++++------- 10 files changed, 229 insertions(+), 153 deletions(-) mode change 100755 => 100644 TTS/tts/models/glow_tts.py mode change 100755 => 100644 TTS/tts/utils/speakers.py diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 02154093..0d3ed039 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -8,10 +8,10 @@ class GST(nn.Module): See https://arxiv.org/pdf/1803.09017""" - def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None): + def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None): super().__init__() self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) - self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim) + self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim) def forward(self, inputs, speaker_embedding=None): enc_out = self.encoder(inputs) @@ -83,19 +83,19 @@ class ReferenceEncoder(nn.Module): class StyleTokenLayer(nn.Module): """NN Module attending to style tokens based on prosody encodings.""" - def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None): + def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None): super().__init__() - self.query_dim = embedding_dim // 2 + self.query_dim = gst_embedding_dim // 2 if d_vector_dim: self.query_dim += d_vector_dim - self.key_dim = embedding_dim // num_heads + self.key_dim = gst_embedding_dim // num_heads self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) nn.init.normal_(self.style_tokens, mean=0, std=0.5) self.attention = MultiHeadAttention( - query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads + query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads ) def forward(self, inputs): diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index 2d2cc111..66842305 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -76,9 +76,6 @@ class BaseTacotron(BaseTTS): self.decoder_backward = None self.coarse_decoder = None - # init multi-speaker layers - self.init_multispeaker(config) - @staticmethod def _format_aux_input(aux_input: Dict) -> Dict: return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) @@ -237,6 +234,7 @@ class BaseTacotron(BaseTTS): def compute_gst(self, inputs, style_input, speaker_embedding=None): """Compute global style token""" if isinstance(style_input, dict): + # multiply each style token with a weight query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) if speaker_embedding is not None: query = torch.cat([query, speaker_embedding.reshape(1, 1, 
-1)], dim=-1) @@ -248,8 +246,10 @@ class BaseTacotron(BaseTTS): gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) gst_outputs = gst_outputs + gst_outputs_att * v_amplifier elif style_input is None: + # ignore style token and return zero tensor gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) else: + # compute style tokens gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable inputs = self._concat_speaker_embedding(inputs, gst_outputs) return inputs diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py old mode 100755 new mode 100644 index 1c631c8e..92c42fa7 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -12,15 +12,19 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import get_speaker_manager +from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec class GlowTTS(BaseTTS): - """Glow TTS models from https://arxiv.org/abs/2005.11129 + """GlowTTS model. - Paper abstract: + Paper:: + https://arxiv.org/abs/2005.11129 + + Paper abstract:: Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS, @@ -145,7 +149,6 @@ class GlowTTS(BaseTTS): g = F.normalize(g).unsqueeze(-1) else: g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] - # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop redisual frames wrt num_squeeze and set y_lengths. @@ -362,12 +365,49 @@ class GlowTTS(BaseTTS): train_audio = ap.inv_melspectrogram(pred_spec.T) return figures, {"audio": train_audio} + @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): return self.train_log(ap, batch, outputs) + @torch.no_grad() + def test_run(self, ap): + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self.get_aux_input()
+        for idx, sen in enumerate(test_sentences):
+            outputs = synthesis(
+                self,
+                sen,
+                self.config,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            )
+
+            test_audios["{}-audio".format(idx)] = outputs["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+        return test_figures, test_audios
+
     def preprocess(self, y, y_lengths, y_max_length, attn=None):
         if y_max_length is not None:
             y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index 33b9cb66..86109e74 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -106,7 +106,7 @@ class SpeedySpeech(BaseTTS):
             if isinstance(config.model_args.length_scale, int)
             else config.model_args.length_scale
         )
-        self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
+        self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
         self.encoder = Encoder(
             config.model_args.hidden_channels,
             config.model_args.hidden_channels,
@@ -228,6 +228,7 @@ class SpeedySpeech(BaseTTS):
         outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
         return outputs
 
+    @torch.no_grad()
    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
        """
        Shapes:
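
The new `GlowTTS.test_run` above returns plain dicts keyed by sentence index. A minimal sketch of consuming them outside the `Trainer` (the `writer` here is a generic tensorboard-style logger and an assumption, not the Trainer's actual logging API):

    # Sketch: consume the figures/audios produced by GlowTTS.test_run().
    # `model` is a trained GlowTTS instance, `ap` its AudioProcessor, and
    # `writer` a hypothetical torch.utils.tensorboard.SummaryWriter.
    figures, audios = model.test_run(ap)
    for name, fig in figures.items():  # "0-prediction", "0-alignment", ...
        writer.add_figure(name, fig, global_step)
    for name, wav in audios.items():  # "0-audio", ...
        writer.add_audio(name, wav, global_step, sample_rate=ap.sample_rate)
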
diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index 7949ddf9..f7dfa70b 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -30,12 +30,11 @@ class Tacotron(BaseTacotron):
         for key in config:
             setattr(self, key, config[key])
 
-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set the speaker embedding channel size used to determine `in_channels` of the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers inferred from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-
-        # speaker and gst embeddings is concat in decoder input
-        if self.num_speakers > 1:
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim
 
         if self.use_gst:
@@ -75,13 +74,11 @@
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
             )
+        # backward pass decoder
         if self.bidirectional_decoder:
             self._init_backward_decoder()
@@ -106,7 +103,9 @@
             self.max_decoder_steps,
         )
 
-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@@ -115,6 +114,7 @@
             mel_specs: [B, T_out, C]
             mel_lengths: [B]
             aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
         """
+        aux_input = self._format_aux_input(aux_input)
         outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
         inputs = self.embedding(text)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -125,12 +125,10 @@
         # global style token
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         # speaker embedding
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@@ -182,7 +180,7 @@
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                # B x 1 x speaker_embed_dim
                embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
                # reshape embedded_speakers
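
With `d_vector_dim` gone from the `GST` call sites above, the layer is now built from the GST config values alone. A minimal sketch of the resulting constructor contract (the numeric values are assumptions standing in for typical `config.gst` defaults):

    import torch

    from TTS.tts.layers.tacotron.gst_layers import GST

    # Sketch: GST construction after this patch; speaker conditioning is no
    # longer wired through the constructor by the Tacotron models.
    gst_layer = GST(
        num_mel=80,             # matches decoder_output_dim in the models above
        num_heads=4,            # assumed config.gst values
        num_style_tokens=10,
        gst_embedding_dim=256,
    )
    style_embed = gst_layer(torch.rand(1, 120, 80))  # -> [1, 1, gst_embedding_dim]
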
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index 19619662..c6df0706 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -31,12 +31,11 @@ class Tacotron2(BaseTacotron):
         for key in config:
             setattr(self, key, config[key])
 
-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set the speaker embedding channel size used to determine `in_channels` of the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers inferred from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-
-        # speaker and gst embeddings is concat in decoder input
-        if self.num_speakers > 1:
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim
 
         if self.use_gst:
@@ -47,6 +46,7 @@
 
         # base model layers
         self.encoder = Encoder(self.encoder_in_features)
+
         self.decoder = Decoder(
             self.decoder_in_features,
             self.decoder_output_dim,
@@ -73,9 +73,6 @@
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
             )
@@ -110,7 +107,9 @@
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments
 
-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@@ -130,11 +129,10 @@
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@@ -186,8 +184,9 @@
         if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
+
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
                # reshape embedded_speakers
                if embedded_speakers.ndim == 1:
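
Net effect of the two model files above: speaker conditioning now keys off `use_speaker_embedding` / `use_d_vector_file` rather than `num_speakers > 1`. A minimal sketch of both paths, mirroring the tests further down (the inline `Tacotron2Config(...)` construction and its field values are assumptions; the tests build a shared `config_global` instead):

    import torch

    from TTS.tts.configs import Tacotron2Config  # import path assumed for this revision
    from TTS.tts.models.tacotron2 import Tacotron2

    # Dummy batch with the same shapes the tests below use.
    text = torch.randint(0, 24, (8, 128)).long()
    text_lengths = torch.sort(torch.randint(100, 128, (8,)).long(), descending=True)[0]
    mel_spec = torch.rand(8, 30, 80)  # assumes num_mels=80
    mel_lengths = torch.randint(20, 30, (8,)).long()
    mel_lengths[0] = 30

    # 1) Learned speaker-embedding path: the layer is sized from `num_speakers`.
    config = Tacotron2Config(num_chars=32, r=1, num_speakers=5, use_speaker_embedding=True)  # assumed minimal config
    model = Tacotron2(config)
    outputs = model.forward(
        text, text_lengths, mel_spec, mel_lengths, aux_input={"speaker_ids": torch.randint(0, 5, (8,))}
    )

    # 2) External d-vector path: no embedding layer; vectors arrive via aux_input.
    config = Tacotron2Config(num_chars=32, r=1, use_d_vector_file=True, d_vector_dim=55)
    model = Tacotron2(config)
    outputs = model.forward(
        text, text_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": torch.rand(8, 55)}
    )
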
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
old mode 100755
new mode 100644
index ed14cd8e..1b9ab96f
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -360,10 +360,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     elif c.use_d_vector_file and c.d_vector_file:
         # new speaker manager with external speaker embeddings.
         speaker_manager.set_d_vectors_from_file(c.d_vector_file)
-    elif c.use_d_vector_file and not c.d_vector_file:  # new speaker manager with speaker IDs file.
-        raise "use_d_vector_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
+    elif c.use_d_vector_file and not c.d_vector_file:
+        raise ValueError("use_d_vector_file is True, so you need to pass an external speaker embedding file.")
+    elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
+        # new speaker manager with speaker IDs file.
+        speaker_manager.set_speaker_ids_from_file(c.speakers_file)
     print(
-        " > Training with {} speakers: {}".format(
+        " > Speaker manager is loaded with {} speakers: {}".format(
             speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
         )
     )
diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py
index 3313b8c4..1a8d78bf 100644
--- a/tests/tts_tests/test_tacotron2_d-vectors_train.py
+++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py
@@ -29,6 +29,7 @@ config = Tacotron2Config(
         "Be a voice, not an echo.",
     ],
     d_vector_file="tests/data/ljspeech/speakers.json",
+    d_vector_dim=256,
     max_decoder_steps=50,
 )
diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py
index a8132467..65d2bd9d 100644
--- a/tests/tts_tests/test_tacotron2_model.py
+++ b/tests/tts_tests/test_tacotron2_model.py
@@ -25,8 +25,68 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 
 
 class TacotronTrainTest(unittest.TestCase):
+    """Test vanilla Tacotron2 model."""
+
     def test_train_step(self):  # pylint: disable=no-self-use
         config = config_global.copy()
+        config.use_speaker_embedding = False
+        config.num_speakers = 1
+
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
+        input_lengths = torch.sort(input_lengths, descending=True)[0]
+        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[0] = 30
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()) :, 0] = 1.0
+
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+
+        criterion = MSELossMasked(seq_len_norm=False).to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron2(config).to(device)
+        model.train()
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=config.lr)
+        for i in range(5):
+            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
+            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
+            optimizer.zero_grad()
+            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
+            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
+            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            # ignore pre-highway layer since it works conditionally
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref + ) + count += 1 + + +class MultiSpeakerTacotronTrainTest(unittest.TestCase): + """Test multi-speaker Tacotron2 with speaker embedding layer""" + + @staticmethod + def test_train_step(): + config = config_global.copy() + config.use_speaker_embedding = True + config.num_speakers = 5 + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] @@ -45,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase): criterion = MSELossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) + config.d_vector_dim = 55 model = Tacotron2(config).to(device) model.train() model_ref = copy.deepcopy(model) @@ -76,65 +137,18 @@ class TacotronTrainTest(unittest.TestCase): count += 1 -class MultiSpeakeTacotronTrainTest(unittest.TestCase): - @staticmethod - def test_train_step(): - config = config_global.copy() - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8,)).long().to(device) - input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device) - mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - mel_lengths[0] = 30 - stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.rand(8, 55).to(device) - - for idx in mel_lengths: - stop_targets[:, int(idx.item()) :, 0] = 1.0 - - stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() - - criterion = MSELossMasked(seq_len_norm=False).to(device) - criterion_st = nn.BCEWithLogitsLoss().to(device) - config.d_vector_dim = 55 - model = Tacotron2(config).to(device) - model.train() - model_ref = copy.deepcopy(model) - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - assert (param - param_ref).sum() == 0, param - count += 1 - optimizer = optim.Adam(model.parameters(), lr=config.lr) - for i in range(5): - outputs = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids} - ) - assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0 - assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0 - optimizer.zero_grad() - loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths) - stop_loss = criterion_st(outputs["stop_tokens"], stop_targets) - loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss - loss.backward() - optimizer.step() - # check parameter changes - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - # ignore pre-higway layer since it works conditional - # if count not in [145, 59]: - assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format(
-                count, param.shape, param, param_ref
-            )
-            count += 1
-
-
 class TacotronGSTTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with Global Style Token and Speaker Embedding"""
+
     # pylint: disable=no-self-use
     def test_train_step(self):
         # with random gst mel style
         config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 10
+        config.use_gst = True
+        config.gst = GSTConfig()
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
@@ -247,9 +261,17 @@ class TacotronGSTTrainTest(unittest.TestCase):
 
 
 class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""
+
     @staticmethod
     def test_train_step():
+        config = config_global.copy()
+        config.use_d_vector_file = True
+
+        config.use_gst = True
+        config.gst = GSTConfig()
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py
index 6c673568..3f570276 100644
--- a/tests/tts_tests/test_tacotron_model.py
+++ b/tests/tts_tests/test_tacotron_model.py
@@ -32,6 +32,61 @@ class TacotronTrainTest(unittest.TestCase):
     @staticmethod
     def test_train_step():
         config = config_global.copy()
+        config.use_speaker_embedding = False
+        config.num_speakers = 1
+
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[-1] = mel_spec.size(1)
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()) :, 0] = 1.0
+
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+
+        criterion = L1LossMasked(seq_len_norm=False).to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
+        model.train()
+        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=config.lr)
+        for _ in range(5):
+            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+            optimizer.zero_grad()
+            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
+            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
+            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            # ignore pre-highway layer since it works conditionally
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
+            count += 1
+
+
+class MultiSpeakerTacotronTrainTest(unittest.TestCase):
+    @staticmethod
+    def test_train_step():
+        config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 5
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
         linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
         mel_lengths[-1] = mel_spec.size(1)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
@@ -50,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):
 
         criterion = L1LossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
+        config.d_vector_dim = 55
         model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
@@ -80,63 +136,14 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1
 
 
-class MultiSpeakeTacotronTrainTest(unittest.TestCase):
-    @staticmethod
-    def test_train_step():
-        config = config_global.copy()
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
-        input_lengths[-1] = 128
-        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
-        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        mel_lengths[-1] = mel_spec.size(1)
-        stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        speaker_embeddings = torch.rand(8, 55).to(device)
-
-        for idx in mel_lengths:
-            stop_targets[:, int(idx.item()) :, 0] = 1.0
-
-        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
-        criterion = L1LossMasked(seq_len_norm=False).to(device)
-        criterion_st = nn.BCEWithLogitsLoss().to(device)
-        config.d_vector_dim = 55
-        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
-        model.train()
-        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
-        model_ref = copy.deepcopy(model)
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizer = optim.Adam(model.parameters(), lr=config.lr)
-        for _ in range(5):
-            outputs = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
-            )
-            optimizer.zero_grad()
-            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
-            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
-            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
-            loss.backward()
-            optimizer.step()
-        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            # ignore pre-higway layer since it works conditional
-            # if count not in [145, 59]:
-            assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( - count, param.shape, param, param_ref - ) - count += 1 - - class TacotronGSTTrainTest(unittest.TestCase): @staticmethod def test_train_step(): config = config_global.copy() + config.use_speaker_embedding = True + config.num_speakers = 10 + config.use_gst = True + config.gst = GSTConfig() # with random gst mel style input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device) @@ -244,6 +251,11 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): config = config_global.copy() + config.use_d_vector_file = True + + config.use_gst = True + config.gst = GSTConfig() + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128
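
One usage note on the GST path these last tests exercise: `BaseTacotron.compute_gst` (see the base_tacotron.py hunk near the top) dispatches on the type of its style input, so the same entry point covers token-weight dicts, reference spectrograms, and no style at all. A sketch of the three accepted forms (the `inference` call shape is inferred from the hunks above, not verified):

    # Sketch: the three style_input forms BaseTacotron.compute_gst() accepts.
    style = {0: 0.3, 3: -0.1}       # dict: style-token index -> amplifier weight
    style = torch.rand(1, 120, 80)  # tensor: reference mel, run through the GST encoder
    style = None                    # None: a zero GST vector is concatenated instead

    outputs = model.inference(text, aux_input={"style_mel": style, "d_vectors": d_vector})
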