From 235f7d9b026b4e7a0bec09c5e36f81eb019f7420 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 12 Jan 2022 11:35:52 +0000
Subject: [PATCH] Extend glow_tts model tests

---
 TTS/tts/models/glow_tts.py       |  63 +++++--
 tests/tts_tests/test_glow_tts.py | 293 +++++++++++++++++++++++++++----
 2 files changed, 300 insertions(+), 56 deletions(-)

diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 7a48b023..869adcad 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -40,11 +40,20 @@ class GlowTTS(BaseTTS):
     Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
 
     Examples:
+        Init only the model layers.
+
+        >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
+        >>> from TTS.tts.models.glow_tts import GlowTTS
+        >>> config = GlowTTSConfig(num_chars=2)
+        >>> model = GlowTTS(config)
+
+        Fully init a model ready for action. All the class attributes and class members
+        (e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.
+
         >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
         >>> from TTS.tts.models.glow_tts import GlowTTS
         >>> config = GlowTTSConfig()
-        >>> model = GlowTTS(config)
-
+        >>> model = GlowTTS.init_from_config(config, verbose=False)
     """
 
     def __init__(
@@ -98,25 +107,23 @@ class GlowTTS(BaseTTS):
     def init_multispeaker(self, config: Coqpit):
         """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
-        vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
+        vector dimension to the encoder layer channel size. If the model uses d-vectors, it only sets the
+        speaker embedding vector dimension to the d-vector dimension from the config.
 
         Args:
             config (Coqpit): Model configuration.
         """
         self.embedded_speaker_dim = 0
-        # init speaker manager
-        if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
-            raise ValueError(
-                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
-            )
         # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
         if self.speaker_manager is not None:
             self.num_speakers = self.speaker_manager.num_speakers
         # set ultimate speaker embedding size
-        if config.use_speaker_embedding or config.use_d_vector_file:
+        if config.use_d_vector_file:
             self.embedded_speaker_dim = (
                 config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
             )
+            if self.speaker_manager is not None:
+                assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch between config and speaker manager."
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             print(" > Init speaker_embedding layer.")
@@ -186,12 +193,33 @@ class GlowTTS(BaseTTS):
         self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
     ):  # pylint: disable=dangerous-default-value
         """
-        Shapes:
-            - x: :math:`[B, T]`
-            - x_lenghts::math:`B`
-            - y: :math:`[B, T, C]`
-            - y_lengths::math:`B`
-            - g: :math:`[B, C] or B`
+        Args:
+            x (torch.Tensor):
+                Input text sequence ids. :math:`[B, T_en]`
+
+            x_lengths (torch.Tensor):
+                Lengths of input text sequences. :math:`[B]`
+
+            y (torch.Tensor):
+                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
+
+            y_lengths (torch.Tensor):
+                Lengths of target mel-spectrogram frames. :math:`[B]`
+
+            aux_input (Dict):
+                Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model.
+                :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model using a speaker-embedding
+                layer. :math:`[B]`
+
+        Returns:
+            Dict:
+                - z: :math:`[B, T_de, C]`
+                - logdet: :math:`[B]`
+                - y_mean: :math:`[B, T_de, C]`
+                - y_log_scale: :math:`[B, T_de, C]`
+                - alignments: :math:`[B, T_en, T_de]`
+                - durations_log: :math:`[B, T_en, 1]`
+                - total_durations_log: :math:`[B, T_en, 1]`
         """
         # [B, T, C] -> [B, C, T]
         y = y.transpose(1, 2)
@@ -510,17 +538,18 @@ class GlowTTS(BaseTTS):
         self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
 
     @staticmethod
-    def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
         """Initiate model from config
 
         Args:
             config (VitsConfig): Model config.
             samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                 Defaults to None.
+            verbose (bool): If True, print init messages. Defaults to True.
         """
         from TTS.utils.audio import AudioProcessor
 
-        ap = AudioProcessor.init_from_config(config)
+        ap = AudioProcessor.init_from_config(config, verbose)
         tokenizer, new_config = TTSTokenizer.init_from_config(config)
         speaker_manager = SpeakerManager.init_from_config(config, samples)
         return GlowTTS(new_config, ap, tokenizer, speaker_manager)
diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py
index 82d0ec3b..e97b793a 100644
--- a/tests/tts_tests/test_glow_tts.py
+++ b/tests/tts_tests/test_glow_tts.py
@@ -1,11 +1,13 @@
 import copy
 import os
 import unittest
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.logging.tensorboard_logger import TensorboardLogger
 
 import torch
 from torch import optim
 
-from tests import get_tests_input_path
+from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.models.glow_tts import GlowTTS
@@ -28,36 +30,211 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
-class GlowTTSTrainTest(unittest.TestCase):
-    @staticmethod
-    def test_train_step():
+class TestGlowTTS(unittest.TestCase):
+    def _create_inputs(self):
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
         speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
+        return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
 
+    def _check_parameter_changes(self, model, model_ref):
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
\n{}\n{}".format( + count, param.shape, param, param_ref + ) + count += 1 + + def test_init_multispeaker(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config) + # speaker embedding with default speaker_embedding_dim + config.use_speaker_embedding = True + config.num_speakers = 5 + config.d_vector_dim = None + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, model.hidden_channels_enc) + # use external speaker embeddings with speaker_embedding_dim = 301 + config = GlowTTSConfig(num_chars=32) + config.use_d_vector_file = True + config.d_vector_dim = 301 + model = GlowTTS(config) + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, 301) + # use speaker embedddings by the provided speaker_manager + config = GlowTTSConfig(num_chars=32) + config.use_speaker_embedding = True + config.speakers_file = os.path.join(get_tests_data_path(), "ljspeech", "speakers.json") + speaker_manager = SpeakerManager.init_from_config(config) + model = GlowTTS(config) + model.speaker_manager = speaker_manager + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, model.hidden_channels_enc) + self.assertEqual(model.num_speakers, speaker_manager.num_speakers) + # use external speaker embeddings by the provided speaker_manager + config = GlowTTSConfig(num_chars=32) + config.use_d_vector_file = True + config.d_vector_dim = 256 + config.d_vector_file = os.path.join(get_tests_data_path(), "dummy_speakers.json") + speaker_manager = SpeakerManager.init_from_config(config) + model = GlowTTS(config) + model.speaker_manager = speaker_manager + model.init_multispeaker(config) + self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim) + self.assertEqual(model.num_speakers, speaker_manager.num_speakers) + + def test_unlock_act_norm_layers(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.unlock_act_norm_layers() + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + self.assertFalse(f.initialized) + + def test_lock_act_norm_layers(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.lock_act_norm_layers() + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + self.assertTrue(f.initialized) + + def test_forward(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + # create model + config = GlowTTSConfig(num_chars=32) + model = GlowTTS(config).to(device) + model.train() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + # inference encoder and decoder with MAS + y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) + self.assertEqual(y["z"].shape, mel_spec.shape) + self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["y_mean"].shape, mel_spec.shape) + self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) + self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) + self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) + self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + + def test_forward_with_d_vector(self): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + d_vector = torch.rand(8, 256).to(device) + # create model + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + model = 
+        model.train()
+        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+        # forward pass: encoder and decoder with MAS
+        y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector})
+        self.assertEqual(y["z"].shape, mel_spec.shape)
+        self.assertEqual(y["logdet"].shape, torch.Size([8]))
+        self.assertEqual(y["y_mean"].shape, mel_spec.shape)
+        self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
+        self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
+        self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
+        self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
+
+    def test_forward_with_speaker_id(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+        speaker_ids = torch.randint(0, 24, (8,)).long().to(device)
+        # create model
+        config = GlowTTSConfig(
+            num_chars=32,
+            use_speaker_embedding=True,
+            num_speakers=24,
+        )
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model.train()
+        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+        # forward pass: encoder and decoder with MAS
+        y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids})
+        self.assertEqual(y["z"].shape, mel_spec.shape)
+        self.assertEqual(y["logdet"].shape, torch.Size([8]))
+        self.assertEqual(y["y_mean"].shape, mel_spec.shape)
+        self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
+        self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
+        self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
+        self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
+
+    def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
+        output_shape = outputs["model_outputs"].shape
+        self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2])
+        self.assertEqual(outputs["logdet"], None)
+        self.assertEqual(outputs["y_mean"].shape, output_shape)
+        self.assertEqual(outputs["y_log_scale"].shape, output_shape)
+        self.assertEqual(outputs["alignments"].shape, output_shape[:2] + (input_dummy.shape[1],))
+        self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,))
+        self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,))
+
+    def test_inference(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+        config = GlowTTSConfig(num_chars=32)
+        model = GlowTTS(config).to(device)
+        model.eval()
+        outputs = model.inference(input_dummy, {"x_lengths": input_lengths})
+        self._assert_inference_outputs(outputs, input_dummy, mel_spec)
+
+    def test_inference_with_d_vector(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+        d_vector = torch.rand(8, 256).to(device)
+        config = GlowTTSConfig(
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        )
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model.eval()
+        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
+        self._assert_inference_outputs(outputs, input_dummy, mel_spec)
+
+    def test_inference_with_speaker_ids(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+        speaker_ids = torch.randint(0, 24, (8,)).long().to(device)
+        # create model
+        config = GlowTTSConfig(
+            num_chars=32,
+            use_speaker_embedding=True,
+            num_speakers=24,
+        )
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
+        self._assert_inference_outputs(outputs, input_dummy, mel_spec)
+
+    def test_inference_with_MAS(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+        # create model
+        config = GlowTTSConfig(num_chars=32)
+        model = GlowTTS(config).to(device)
+        model.eval()
+        # inference encoder and decoder with MAS
+        y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
+        y2 = model.decoder_inference(mel_spec, mel_lengths)
+        assert (
+            y2["model_outputs"].shape == y["model_outputs"].shape
+        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
+            y["model_outputs"].shape, y2["model_outputs"].shape
+        )
+
+    def test_train_step(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
         criterion = GlowTTSLoss()
-        # model to train
         config = GlowTTSConfig(num_chars=32)
         model = GlowTTS(config).to(device)
-        # reference model to compare model weights
         model_ref = GlowTTS(config).to(device)
-
         model.train()
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
-        # pass the state to ref model
         model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
-
         count = 0
         for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
-
         optimizer = optim.Adam(model.parameters(), lr=0.001)
         for _ in range(5):
             optimizer.zero_grad()
@@ -75,40 +252,78 @@ class GlowTTSTrainTest(unittest.TestCase):
             loss = loss_dict["loss"]
             loss.backward()
             optimizer.step()
-        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
\n{}\n{}".format( - count, param.shape, param, param_ref - ) - count += 1 + self._check_parameter_changes(model, model_ref) - -class GlowTTSInferenceTest(unittest.TestCase): - @staticmethod - def test_inference(): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) - input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) - speaker_ids = torch.randint(0, 5, (8,)).long().to(device) - - # create model + def test_train_eval_log(self): + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs() + batch = {} + batch["text_input"] = input_dummy + batch["text_lengths"] = input_lengths + batch["mel_lengths"] = mel_lengths + batch["mel_input"] = mel_spec + batch["d_vectors"] = None + batch["speaker_ids"] = None config = GlowTTSConfig(num_chars=32) - model = GlowTTS(config).to(device) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False + model.train() + logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log") + criterion = model.get_criterion() + outputs, _ = model.train_step(batch, criterion) + model.train_log(batch, outputs, logger, None, 1) + model.eval_log(batch, outputs, logger, None, 1) + logger.finish() + def test_test_run(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + model.run_data_dep_init = False model.eval() - print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + test_figures, test_audios = model.test_run(None) + self.assertTrue(test_figures is not None) + self.assertTrue(test_audios is not None) - # inference encoder and decoder with MAS - y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths) + def test_load_checkpoint(self): + chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + chkp = {} + chkp["model"] = model.state_dict() + torch.save(chkp, chkp_path) + model.load_checkpoint(config, chkp_path) + self.assertTrue(model.training) + model.load_checkpoint(config, chkp_path, eval=True) + self.assertFalse(model.training) - y2 = model.decoder_inference(mel_spec, mel_lengths) + def test_get_criterion(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + criterion = model.get_criterion() + self.assertTrue(criterion is not None) + + def test_init_from_config(self): + config = GlowTTSConfig(num_chars=32) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + + config = GlowTTSConfig(num_chars=32, num_speakers=2) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 2) + self.assertTrue(not hasattr(model, "emb_g")) + + config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 2) + self.assertTrue(hasattr(model, "emb_g")) + + config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")) + model = GlowTTS.init_from_config(config, verbose=False).to(device) + 
+        self.assertTrue(model.num_speakers == 10)
+        self.assertTrue(hasattr(model, "emb_g"))
+
+        config = GlowTTSConfig(
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        )
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 1)
+        self.assertTrue(not hasattr(model, "emb_g"))
+        self.assertTrue(model.c_in_channels == config.d_vector_dim)
-        assert (
-            y2["model_outputs"].shape == y["model_outputs"].shape
-        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
-            y["model_outputs"].shape, y2["model_outputs"].shape
-        )
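
The new tests rely on the `verbose` flag that this patch threads from `GlowTTS.init_from_config` down to `AudioProcessor.init_from_config`, so a fully initialized model (tokenizer, audio processor, speaker manager) can be built without the usual init printouts. A minimal usage sketch, mirroring the docstring example added above (the `num_chars` value is illustrative, not part of the patch):

    from TTS.tts.configs.glow_tts_config import GlowTTSConfig
    from TTS.tts.models.glow_tts import GlowTTS

    # Build the config and fully initialize the model from it,
    # suppressing the AudioProcessor init messages.
    config = GlowTTSConfig(num_chars=32)
    model = GlowTTS.init_from_config(config, verbose=False)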