From 146fbfd7c90045f69eb027f54e9f3292eed9951c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 13 Jan 2022 17:39:06 +0000
Subject: [PATCH] Extend unittests

---
 TTS/tts/layers/vits/networks.py  |   9 +-
 TTS/tts/models/vits.py           |  46 ++++-
 tests/tts_tests/test_glow_tts.py |  89 ++++++----
 tests/tts_tests/test_vits.py     | 285 +++++++++++++++++++++++++++----
 4 files changed, 361 insertions(+), 68 deletions(-)

diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index 7c225344..f97b584f 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -83,6 +83,7 @@ class TextEncoder(nn.Module):
             - x: :math:`[B, T]`
             - x_lengths: :math:`[B]`
         """
+        assert x.shape[0] == x_lengths.shape[0]
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
 
         # concat the lang emb in embedding chars
@@ -90,7 +91,7 @@ class TextEncoder(nn.Module):
             x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
 
         x = torch.transpose(x, 1, -1)  # [b, h, t]
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
 
         x = self.encoder(x * x_mask, x_mask)
         stats = self.proj(x) * x_mask
@@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module):
     def forward(self, x, x_mask, g=None, reverse=False):
         """
+        Note:
+            Set `reverse` to True for inference.
+
         Shapes:
             - x: :math:`[B, C, T]`
             - x_mask: :math:`[B, 1, T]`
@@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module):
     def forward(self, x, x_mask, g=None, reverse=False):
         """
+        Note:
+            Set `reverse` to True for inference.
+
         Shapes:
             - x: :math:`[B, C, T]`
             - x_mask: :math:`[B, 1, T]`
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 2ecd1a07..4612c02b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -563,6 +563,19 @@ class Vits(BaseTTS):
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
             - language_ids: :math:`[B]`
+
+        Return Shapes:
+            - model_outputs: :math:`[B, 1, T_wav]`
+            - alignments: :math:`[B, T_seq, T_dec]`
+            - z: :math:`[B, C, T_dec]`
+            - z_p: :math:`[B, C, T_dec]`
+            - m_p: :math:`[B, C, T_dec]`
+            - logs_p: :math:`[B, C, T_dec]`
+            - m_q: :math:`[B, C, T_dec]`
+            - logs_q: :math:`[B, C, T_dec]`
+            - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]`
+            - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
+            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
         sid, g, lid = self._set_cond_input(aux_input)
@@ -666,15 +679,33 @@ class Vits(BaseTTS):
         )
         return outputs
 
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
+    @staticmethod
+    def _set_x_lengths(x, aux_input):
+        if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
+            return aux_input["x_lengths"]
+        return torch.tensor(x.shape[1:2]).to(x.device)
+
+    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
+        Note:
+            To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.
+
         Shapes:
             - x: :math:`[B, T_seq]`
-            - d_vectors: :math:`[B, C, 1]`
+            - x_lengths: :math:`[B]`
+            - d_vectors: :math:`[B, C]`
             - speaker_ids: :math:`[B]`
+            - language_ids: :math:`[B]`
+
+        Return Shapes:
+            - model_outputs: :math:`[B, 1, T_wav]`
+            - alignments: :math:`[B, T_seq, T_dec]`
+            - z: :math:`[B, C, T_dec]`
+            - z_p: :math:`[B, C, T_dec]`
+            - m_p: :math:`[B, C, T_dec]`
+            - logs_p: :math:`[B, C, T_dec]`
         """
         sid, g, lid = self._set_cond_input(aux_input)
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+        x_lengths = self._set_x_lengths(x, aux_input)
 
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
@@ -704,8 +735,9 @@ class Vits(BaseTTS):
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
-        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
+
+        attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
 
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
@@ -1004,7 +1036,7 @@ class Vits(BaseTTS):
         assert not self.training
 
     @staticmethod
-    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
         """Initialize model from config
 
         Args:
@@ -1014,7 +1046,7 @@ class Vits(BaseTTS):
         """
         from TTS.utils.audio import AudioProcessor
 
-        ap = AudioProcessor.init_from_config(config)
+        ap = AudioProcessor.init_from_config(config, verbose=verbose)
         tokenizer, new_config = TTSTokenizer.init_from_config(config)
         speaker_manager = SpeakerManager.init_from_config(config, samples)
         language_manager = LanguageManager.init_from_config(config)
diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py
index e48977e9..305f86b8 100644
--- a/tests/tts_tests/test_glow_tts.py
+++ b/tests/tts_tests/test_glow_tts.py
@@ -23,6 +23,7 @@ c = GlowTTSConfig()
 ap = AudioProcessor(**c.audio)
 WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 
+BATCH_SIZE = 3
 
 def count_parameters(model):
@@ -32,13 +33,13 @@ def count_parameters(model):
 
 class TestGlowTTS(unittest.TestCase):
     @staticmethod
-    def _create_inputs():
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+    def _create_inputs(batch_size=8):
+        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
         input_lengths[-1] = 128
-        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
+        mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
+        speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
         return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
 
     @staticmethod
@@ -104,8 +105,8 @@ class TestGlowTTS(unittest.TestCase):
             if getattr(f, "set_ddi", False):
                 self.assertTrue(f.initialized)
 
-    def test_forward(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+    def _test_forward(self, batch_size):
+        input_dummy, 
input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) # create model config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) @@ -114,16 +115,20 @@ class TestGlowTTS(unittest.TestCase): # inference encoder and decoder with MAS y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) self.assertEqual(y["z"].shape, mel_spec.shape) - self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) self.assertEqual(y["y_mean"].shape, mel_spec.shape) self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) - def test_forward_with_d_vector(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - d_vector = torch.rand(8, 256).to(device) + def test_forward(self): + self._test_forward(1) + self._test_forward(3) + + def _test_forward_with_d_vector(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + d_vector = torch.rand(batch_size, 256).to(device) # create model config = GlowTTSConfig( num_chars=32, @@ -137,16 +142,20 @@ class TestGlowTTS(unittest.TestCase): # inference encoder and decoder with MAS y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector}) self.assertEqual(y["z"].shape, mel_spec.shape) - self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) self.assertEqual(y["y_mean"].shape, mel_spec.shape) self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) - def test_forward_with_speaker_id(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - speaker_ids = torch.randint(0, 24, (8,)).long().to(device) + def test_forward_with_d_vector(self): + self._test_forward_with_d_vector(1) + self._test_forward_with_d_vector(3) + + def _test_forward_with_speaker_id(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device) # create model config = GlowTTSConfig( num_chars=32, @@ -159,13 +168,17 @@ class TestGlowTTS(unittest.TestCase): # inference encoder and decoder with MAS y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids}) self.assertEqual(y["z"].shape, mel_spec.shape) - self.assertEqual(y["logdet"].shape, torch.Size([8])) + self.assertEqual(y["logdet"].shape, torch.Size([batch_size])) self.assertEqual(y["y_mean"].shape, mel_spec.shape) self.assertEqual(y["y_log_scale"].shape, mel_spec.shape) self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],)) self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,)) + def test_forward_with_speaker_id(self): + self._test_forward_with_speaker_id(1) + self._test_forward_with_speaker_id(3) + def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): output_shape = 
outputs["model_outputs"].shape self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2]) @@ -176,17 +189,21 @@ class TestGlowTTS(unittest.TestCase): self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,)) self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,)) - def test_inference(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + def _test_inference(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) model.eval() outputs = model.inference(input_dummy, {"x_lengths": input_lengths}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) - def test_inference_with_d_vector(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - d_vector = torch.rand(8, 256).to(device) + def test_inference(self): + self._test_inference(1) + self._test_inference(3) + + def _test_inference_with_d_vector(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + d_vector = torch.rand(batch_size, 256).to(device) config = GlowTTSConfig( num_chars=32, use_d_vector_file=True, @@ -198,9 +215,13 @@ class TestGlowTTS(unittest.TestCase): outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) - def test_inference_with_speaker_ids(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() - speaker_ids = torch.randint(0, 24, (8,)).long().to(device) + def test_inference_with_d_vector(self): + self._test_inference_with_d_vector(1) + self._test_inference_with_d_vector(3) + + def _test_inference_with_speaker_ids(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) + speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device) # create model config = GlowTTSConfig( num_chars=32, @@ -211,8 +232,12 @@ class TestGlowTTS(unittest.TestCase): outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) - def test_inference_with_MAS(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + def test_inference_with_speaker_ids(self): + self._test_inference_with_speaker_ids(1) + self._test_inference_with_speaker_ids(3) + + def _test_inference_with_MAS(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) # create model config = GlowTTSConfig(num_chars=32) model = GlowTTS(config).to(device) @@ -226,8 +251,13 @@ class TestGlowTTS(unittest.TestCase): y["model_outputs"].shape, y2["model_outputs"].shape ) + def test_inference_with_MAS(self): + self._test_inference_with_MAS(1) + self._test_inference_with_MAS(3) + def test_train_step(self): - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() + batch_size = BATCH_SIZE + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size) criterion = GlowTTSLoss() # model to train config = GlowTTSConfig(num_chars=32) @@ -263,7 +293,8 @@ class TestGlowTTS(unittest.TestCase): self._check_parameter_changes(model, model_ref) def test_train_eval_log(self): - input_dummy, 
input_lengths, mel_spec, mel_lengths, _ = self._create_inputs() + batch_size = BATCH_SIZE + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size) batch = {} batch["text_input"] = input_dummy batch["text_lengths"] = input_lengths diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 4274d947..53e7c09e 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -1,9 +1,11 @@ +import copy import os import unittest +from TTS.utils.logging.tensorboard_logger import TensorboardLogger import torch -from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path +from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.config import load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.tts.configs.vits_config import VitsConfig @@ -100,35 +102,35 @@ class TestVits(unittest.TestCase): self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len)) self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len)) - def _init_inputs(self, config): - input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8,)).long().to(device) + def _create_inputs(self, config, batch_size=2): + input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) input_lengths[-1] = 128 - spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device) - spec_lengths = torch.randint(20, 30, (8,)).long().to(device) + spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device) + spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) spec_lengths[-1] = spec.size(2) - waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device) + waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device) return input_dummy, input_lengths, spec, spec_lengths, waveform - def _check_forward_outputs(self, config, output_dict, encoder_config=None): + def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2): self.assertEqual( output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] ) - self.assertEqual(output_dict["alignments"].shape, (8, 128, 30)) + self.assertEqual(output_dict["alignments"].shape, (batch_size, 128, 30)) self.assertEqual(output_dict["alignments"].max(), 1) self.assertEqual(output_dict["alignments"].min(), 0) - self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["z_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["m_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + self.assertEqual(output_dict["logs_p"].shape, (batch_size, config.model_args.hidden_channels, 30)) + 
self.assertEqual(output_dict["m_q"].shape, (batch_size, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["logs_q"].shape, (batch_size, config.model_args.hidden_channels, 30))
         self.assertEqual(
             output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
         )
         if encoder_config:
-            self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
-            self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
         else:
             self.assertEqual(output_dict["gt_spk_emb"], None)
             self.assertEqual(output_dict["syn_spk_emb"], None)
@@ -137,7 +139,7 @@
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
         model = Vits(config).to(device)
         output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
         self._check_forward_outputs(config, output_dict)
@@ -148,7 +150,7 @@
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
-        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
+        speaker_ids = torch.randint(0, num_speakers, (2,)).long().to(device)
 
         model = Vits(config).to(device)
         output_dict = model.forward(
             input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
         )
         self._check_forward_outputs(config, output_dict)
 
+    def test_d_vector_forward(self):
+        batch_size = 2
+        args = VitsArgs(
+            spec_segment_size=10,
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        )
+        config = VitsConfig(model_args=args)
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.train()
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        d_vectors = torch.randn(batch_size, 256).to(device)
+        output_dict = model.forward(
+            input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
+        )
+        self._check_forward_outputs(config, output_dict)
+
     def test_multilingual_forward(self):
         num_speakers = 10
         num_langs = 3
+        batch_size = 2
 
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
-        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
-        lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
         model = Vits(config).to(device)
         output_dict = model.forward(
@@ -182,6 +204,7 @@
     def 
test_secl_forward(self): num_speakers = 10 num_langs = 3 + batch_size = 2 speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG) speaker_encoder_config.model_params["use_torch_spec"] = True @@ -198,9 +221,9 @@ class TestVits(unittest.TestCase): config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) config.audio.sample_rate = 16000 - input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config) - speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) - lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) + input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) model = Vits(config, speaker_manager=speaker_manager).to(device) output_dict = model.forward( @@ -213,28 +236,228 @@ class TestVits(unittest.TestCase): ) self._check_forward_outputs(config, output_dict, speaker_encoder_config) + def _check_inference_outputs(self, config, outputs, input_dummy, batch_size=1): + feat_len = outputs["z"].shape[2] + self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1)) # we don't know the channel dimension + self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len)) + self.assertEqual(outputs["z"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + self.assertEqual(outputs["z_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + self.assertEqual(outputs["m_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + self.assertEqual(outputs["logs_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len)) + def test_inference(self): num_speakers = 0 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) - input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) model = Vits(config).to(device) - _ = model.inference(input_dummy) + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + outputs = model.inference(input_dummy) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + batch_size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) def test_multispeaker_inference(self): num_speakers = 10 config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True) - input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) - speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) model = Vits(config).to(device) - _ = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + batch_size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) + 
self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) def test_multilingual_inference(self): num_speakers = 10 num_langs = 3 args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10) config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) + model = Vits(config).to(device) + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) lang_ids = torch.randint(0, num_langs, (1,)).long().to(device) - model = Vits(config).to(device) _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids}) + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + batch_size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device) + outputs = model.inference( + input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids, "language_ids": lang_ids} + ) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size) + + def test_d_vector_inference(self): + args = VitsArgs( + spec_segment_size=10, + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + config = VitsConfig(model_args=args) + model = Vits.init_from_config(config, verbose=False).to(device) + model.eval() + # batch size = 1 + input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) + d_vectors = torch.randn(1, 256).to(device) + outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors}) + self._check_inference_outputs(config, outputs, input_dummy) + # batch size = 2 + input_dummy, input_lengths, *_ = self._create_inputs(config) + d_vectors = torch.randn(2, 256).to(device) + outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors}) + self._check_inference_outputs(config, outputs, input_dummy, batch_size=2) + + @staticmethod + def _check_parameter_changes(model, model_ref): + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
+            count += 1
+
+    def _create_batch(self, config, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
+        batch = {}
+        batch["text_input"] = input_dummy
+        batch["text_lengths"] = input_lengths
+        batch["mel_lengths"] = mel_lengths
+        batch["linear_input"] = mel_spec.transpose(1, 2)
+        batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device)
+        batch["d_vectors"] = None
+        batch["speaker_ids"] = None
+        batch["language_ids"] = None
+        return batch
+
+    def test_train_step(self):
+        # setup the model
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+        model = Vits(config).to(device)
+        # create a batch
+        batch = self._create_batch(config, 1)
+        # model to train
+        criterions = model.get_criterion()
+        criterions = [criterions[0].to(device), criterions[1].to(device)]
+        # reference model to compare model weights
+        model_ref = Vits(config).to(device)
+        model.train()
+        # pass the state to ref model
+        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizers = model.get_optimizer()
+        for _ in range(5):
+            _, loss_dict = model.train_step(batch, criterions, 0)
+            loss = loss_dict["loss"]
+            loss.backward()
+            optimizers[0].step()
+
+            _, loss_dict = model.train_step(batch, criterions, 1)
+            loss = loss_dict["loss"]
+            loss.backward()
+            optimizers[1].step()
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
+    def test_train_eval_log(self):
+        batch_size = 2
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.train()
+        batch = self._create_batch(config, batch_size)
+        logger = TensorboardLogger(
+            log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"), model_name="vits_test_train_log"
+        )
+        criterion = model.get_criterion()
+        criterion = [criterion[0].to(device), criterion[1].to(device)]
+        outputs = [None] * 2
+        outputs[0], _ = model.train_step(batch, criterion, 0)
+        outputs[1], _ = model.train_step(batch, criterion, 1)
+        model.train_log(batch, outputs, logger, None, 1)
+
+        model.eval_log(batch, outputs, logger, None, 1)
+        logger.finish()
+
+    def test_test_run(self):
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.eval()
+        test_figures, test_audios = model.test_run(None)
+        self.assertTrue(test_figures is not None)
+        self.assertTrue(test_audios is not None)
+
+    def test_load_checkpoint(self):
+        chkp_path = os.path.join(get_tests_output_path(), "dummy_vits_checkpoint.pth")
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        chkp = {}
+        chkp["model"] = model.state_dict()
+        torch.save(chkp, chkp_path)
+        model.load_checkpoint(config, chkp_path)
+        self.assertTrue(model.training)
+        model.load_checkpoint(config, chkp_path, eval=True)
+        self.assertFalse(model.training)
+
+    def test_get_criterion(self):
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        criterion = model.get_criterion()
+        self.assertTrue(criterion is not None)
+
+    def test_init_from_config(self):
+        config = 
VitsConfig(model_args=VitsArgs(num_chars=32)) + model = Vits.init_from_config(config, verbose=False).to(device) + + config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2)) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertTrue(not hasattr(model, "emb_g")) + + config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True)) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertEqual(model.num_speakers, 2) + self.assertTrue(hasattr(model, "emb_g")) + + config = VitsConfig(model_args=VitsArgs( + num_chars=32, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + )) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertEqual(model.num_speakers, 10) + self.assertTrue(hasattr(model, "emb_g")) + + config = VitsConfig(model_args=VitsArgs( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + )) + model = Vits.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 1) + self.assertTrue(not hasattr(model, "emb_g")) + self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
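
Usage notes (illustrative sketches, not part of the patch):

1) Batched inference via the new `x_lengths` key in `aux_input`. A minimal sketch mirroring the tests above; the config values and token ids are dummies:

    import torch

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits, VitsArgs

    config = VitsConfig(model_args=VitsArgs(num_chars=32))
    model = Vits.init_from_config(config, verbose=False)
    model.eval()

    # two zero-padded token sequences with different true lengths
    x = torch.randint(0, 24, (2, 128)).long()
    x_lengths = torch.tensor([128, 100]).long()

    with torch.no_grad():
        outputs = model.inference(x, aux_input={"x_lengths": x_lengths})
    # outputs["model_outputs"]: [2, 1, T_wav]; without `x_lengths`,
    # `_set_x_lengths` falls back to `x.shape[1]`, i.e. a single
    # full-length sequence (batch size 1).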
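
2) The `reverse` flag documented on the coupling blocks: the flow maps the posterior latent forward during training and is inverted with `reverse=True` at inference. A round-trip sketch; the constructor arguments are arbitrary and assumed to match the `ResidualCouplingBlocks` signature in TTS/tts/layers/vits/networks.py:

    import torch

    from TTS.tts.layers.vits.networks import ResidualCouplingBlocks

    flow = ResidualCouplingBlocks(
        channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, num_layers=4
    )
    x = torch.randn(1, 192, 30)
    x_mask = torch.ones(1, 1, 30)
    z_p = flow(x, x_mask)                    # training direction
    x_hat = flow(z_p, x_mask, reverse=True)  # inference direction
    # coupling layers are exactly invertible, so x_hat should recover x
    # up to floating-point error
    assert torch.allclose(x, x_hat, atol=1e-4)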
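
3) The rewritten `attn_mask` in `Vits.inference` is plain broadcasting; once `y_mask` carries its singleton channel dim, no extra unsqueezing is needed. A shape check in bare torch:

    import torch

    B, T_enc, T_dec = 2, 5, 7
    x_mask = torch.ones(B, 1, T_enc)
    y_mask = torch.ones(B, 1, T_dec)
    attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
    assert attn_mask.shape == (B, T_dec, T_enc)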
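
4) External speaker embeddings at inference, following `test_d_vector_inference`. The d-vector here is random noise purely for shape illustration; a real one would come from the speaker encoder:

    import os

    import torch

    from tests import get_tests_data_path
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits, VitsArgs

    args = VitsArgs(
        num_chars=32,
        use_d_vector_file=True,
        d_vector_dim=256,
        d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
    )
    model = Vits.init_from_config(VitsConfig(model_args=args), verbose=False).eval()

    x = torch.randint(0, 24, (1, 128)).long()
    d_vectors = torch.randn(1, 256)  # [B, C], matching the updated docstring
    outputs = model.inference(x, aux_input={"d_vectors": d_vectors})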