From c0b40a0cb76b3d7668521537433e944c78ad5bf1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 11:56:21 +0100
Subject: [PATCH] Update VITS tests

---
 tests/tts_tests/test_vits.py | 114 ++++++++++++++++++++++-------------
 1 file changed, 71 insertions(+), 43 deletions(-)

diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index eaa325b0..4018c6bd 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -3,17 +3,19 @@ import os
 import unittest
 
 import torch
+from TTS.tts.datasets.formatters import ljspeech
 
 from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from trainer.logging.tensorboard_logger import TensorboardLogger
 
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 
 torch.manual_seed(1)
 
@@ -23,6 +25,28 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # pylint: disable=no-self-use
 class TestVits(unittest.TestCase):
+    def test_load_audio(self):
+        wav, sr = load_audio(WAV_FILE)
+        self.assertEqual(wav.shape, (1, 41885))
+        self.assertEqual(sr, 22050)
+
+        spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
+        mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
+        mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
+
+        self.assertEqual((mel - mel2).abs().max(), 0)
+        self.assertEqual(spec.shape[0], mel.shape[0])
+        self.assertEqual(spec.shape[2], mel.shape[2])
+
+        spec_db = amp_to_db(spec)
+        spec_amp = db_to_amp(spec_db)
+
+        self.assertAlmostEqual((spec - spec_amp).abs().max(), 0, delta=1e-4)
+
+    def test_dataset(self):
+        """TODO:"""
+        ...
+
     def test_init_multispeaker(self):
         num_speakers = 10
         args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
@@ -107,10 +131,11 @@ class TestVits(unittest.TestCase):
         input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
         input_lengths[-1] = 128
         spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device)
+        mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device)
         spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
         spec_lengths[-1] = spec.size(2)
         waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device)
-        return input_dummy, input_lengths, spec, spec_lengths, waveform
+        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform
 
     def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2):
         self.assertEqual(
@@ -139,7 +164,7 @@ class TestVits(unittest.TestCase):
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
         model = Vits(config).to(device)
         output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
         self._check_forward_outputs(config, output_dict)
@@ -150,7 +175,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
         speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
 
         model = Vits(config).to(device)
@@ -171,7 +196,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(model_args=args)
         model = Vits.init_from_config(config, verbose=False).to(device)
         model.train()
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         d_vectors = torch.randn(batch_size, 256).to(device)
         output_dict = model.forward(
             input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
         )
@@ -186,7 +211,7 @@ class TestVits(unittest.TestCase):
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
@@ -221,7 +246,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
         config.audio.sample_rate = 16000
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
@@ -330,20 +355,25 @@ class TestVits(unittest.TestCase):
     @staticmethod
     def _check_parameter_changes(model, model_ref):
         count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+        for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()):
+            name = item1[0]
+            param = item1[1]
+            param_ref = item2[1]
             assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref
+                name, param.shape, param, param_ref
             )
-            count += 1
+            count = count + 1
 
     def _create_batch(self, config, batch_size):
-        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
+        input_dummy, input_lengths, mel, spec, mel_lengths, _ = self._create_inputs(config, batch_size)
         batch = {}
-        batch["text_input"] = input_dummy
-        batch["text_lengths"] = input_lengths
-        batch["mel_lengths"] = mel_lengths
-        batch["linear_input"] = mel_spec.transpose(1, 2)
-        batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device)
+        batch["tokens"] = input_dummy
+        batch["token_lens"] = input_lengths
+        batch["spec_lens"] = mel_lengths
+        batch["mel_lens"] = mel_lengths
+        batch["spec"] = spec
+        batch["mel"] = mel
+        batch["waveform"] = torch.rand(batch_size, 1, config.audio["sample_rate"] * 10).to(device)
         batch["d_vectors"] = None
         batch["speaker_ids"] = None
         batch["language_ids"] = None
@@ -351,33 +381,31 @@ class TestVits(unittest.TestCase):
 
     def test_train_step(self):
         # setup the model
-        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
-        model = Vits(config).to(device)
-        # create a batch
-        batch = self._create_batch(config, 1)
-        # model to train
-        criterions = model.get_criterion()
-        criterions = [criterions[0].to(device), criterions[1].to(device)]
-        # reference model to compare model weights
-        model_ref = Vits(config).to(device)
-        model.train()
-        # pass the state to ref model
-        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizers = model.get_optimizer()
-        for _ in range(5):
-            _, loss_dict = model.train_step(batch, criterions, 0)
-            loss = loss_dict["loss"]
-            loss.backward()
-            optimizers[0].step()
+        with torch.autograd.set_detect_anomaly(True):
+
+            config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    _, loss_dict = model.train_step(batch, criterions, idx)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
 
-        _, loss_dict = model.train_step(batch, criterions, 1)
-        loss = loss_dict["loss"]
-        loss.backward()
-        optimizers[1].step()
         # check parameter changes
         self._check_parameter_changes(model, model_ref)
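
Note: the amp_to_db / db_to_amp round trip asserted by the new test_load_audio can be sanity-checked without the TTS package. The snippet below is a minimal sketch, assuming the helpers are plain log-compression functions of the form log(clamp(x, min=clip_val)) and exp(x); the clip_val floor of 1e-5 is an illustrative assumption, not taken from this patch.

import torch


def amp_to_db(x, clip_val=1e-5):
    # Dynamic-range compression: clamp away zeros, then log-compress.
    return torch.log(torch.clamp(x, min=clip_val))


def db_to_amp(x):
    # Inverse of amp_to_db (exact wherever the clamp did not bind).
    return torch.exp(x)


spec = torch.rand(1, 513, 30)  # stand-in linear spectrogram: (batch, n_fft // 2 + 1, frames)
spec_amp = db_to_amp(amp_to_db(spec))

# Mirrors the test's assertAlmostEqual(..., delta=1e-4): the round trip is
# exact up to floating-point error when no values are clamped.
assert (spec - spec_amp).abs().max() < 1e-4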