diff --git a/tests/__init__.py b/tests/__init__.py index 45aee23a..0a0c3379 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -38,3 +38,14 @@ def run_cli(command): def get_test_data_config(): return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv") + + +def assertHasAttr(test_obj, obj, intendedAttr): + # from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead + testBool = hasattr(obj, intendedAttr) + test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}") + + +def assertHasNotAttr(test_obj, obj, intendedAttr): + testBool = hasattr(obj, intendedAttr) + test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. obj: {obj}, intendedAttr: {intendedAttr}") diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 335472a5..de075a5c 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -1,13 +1,14 @@ import os -import torch import unittest -from TTS.config import load_config -from TTS.tts.models.vits import Vits, VitsArgs -from TTS.tts.configs.vits_config import VitsConfig -from TTS.tts.utils.speakers import SpeakerManager -from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path -from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model +import torch + +from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path +from TTS.config import load_config +from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.models.vits import Vits, VitsArgs +from TTS.tts.utils.speakers import SpeakerManager LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") @@ -18,21 +19,21 @@ use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +#pylint: disable=no-self-use class TestVits(unittest.TestCase): - def test_init_multispeaker(self): num_speakers = 10 args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True) model = Vits(args) - assertHasAttr(self, model, 'emb_g') + assertHasAttr(self, model, "emb_g") args = VitsArgs(num_speakers=0, use_speaker_embedding=True) model = Vits(args) - assertHasNotAttr(self, model, 'emb_g') + assertHasNotAttr(self, model, "emb_g") args = VitsArgs(num_speakers=10, use_speaker_embedding=False) model = Vits(args) - assertHasNotAttr(self, model, 'emb_g') + assertHasNotAttr(self, model, "emb_g") args = VitsArgs(d_vector_dim=101, use_d_vector_file=True) model = Vits(args) @@ -67,12 +68,12 @@ class TestVits(unittest.TestCase): aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None} args = VitsArgs() model = Vits(args) - aux_out= model.get_aux_input(aux_input) + aux_out = model.get_aux_input(aux_input) speaker_id = torch.randint(10, (1,)) language_id = torch.randint(10, (1,)) d_vector = torch.rand(1, 128) - aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id} + aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id} aux_out = model.get_aux_input(aux_input) self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape) self.assertEqual(aux_out["language_ids"].shape, language_id.shape) @@ -88,8 +89,8 @@ class TestVits(unittest.TestCase): ref_inp = torch.randn(1, spec_len, 513) ref_inp_len = torch.randint(1, spec_effective_len, (1,)) - ref_spk_id = torch.randint(0, num_speakers, (1,)) - tgt_spk_id = torch.randint(0, num_speakers, (1,)) + ref_spk_id = torch.randint(1, num_speakers, (1,)) + tgt_spk_id = torch.randint(1, num_speakers, (1,)) o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id) self.assertEqual(o_hat.shape, (1, 1, spec_len * 256)) @@ -110,7 +111,9 @@ class TestVits(unittest.TestCase): return input_dummy, input_lengths, spec, spec_lengths, waveform def _check_forward_outputs(self, config, output_dict, encoder_config=None): - self.assertEqual(output_dict['model_outputs'].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]) + self.assertEqual( + output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] + ) self.assertEqual(output_dict["alignments"].shape, (8, 128, 30)) self.assertEqual(output_dict["alignments"].max(), 1) self.assertEqual(output_dict["alignments"].min(), 0) @@ -120,13 +123,15 @@ class TestVits(unittest.TestCase): self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30)) self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30)) self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30)) - self.assertEqual(output_dict['waveform_seg'].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]) + self.assertEqual( + output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] + ) if encoder_config: - self.assertEqual(output_dict['gt_spk_emb'].shape, (8, encoder_config.model_params["proj_dim"])) - self.assertEqual(output_dict['syn_spk_emb'].shape, (8, encoder_config.model_params["proj_dim"])) + self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) + self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"])) else: - self.assertEqual(output_dict['gt_spk_emb'], None) - self.assertEqual(output_dict['syn_spk_emb'], None) + self.assertEqual(output_dict["gt_spk_emb"], None) + self.assertEqual(output_dict["syn_spk_emb"], None) def test_forward(self): num_speakers = 0 @@ -147,7 +152,9 @@ class TestVits(unittest.TestCase): speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device) model = Vits(config).to(device) - output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}) + output_dict = model.forward( + input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids} + ) self._check_forward_outputs(config, output_dict) def test_multilingual_forward(self): @@ -162,7 +169,14 @@ class TestVits(unittest.TestCase): lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) model = Vits(config).to(device) - output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids}) + output_dict = model.forward( + input_dummy, + input_lengths, + spec, + spec_lengths, + waveform, + aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids}, + ) self._check_forward_outputs(config, output_dict) def test_secl_forward(self): @@ -175,7 +189,12 @@ class TestVits(unittest.TestCase): speaker_manager = SpeakerManager() speaker_manager.speaker_encoder = speaker_encoder - args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10, use_speaker_encoder_as_loss=True) + args = VitsArgs( + language_ids_file=LANG_FILE, + use_language_embedding=True, + spec_segment_size=10, + use_speaker_encoder_as_loss=True, + ) config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args) config.audio.sample_rate = 16000 @@ -184,7 +203,14 @@ class TestVits(unittest.TestCase): lang_ids = torch.randint(0, num_langs, (8,)).long().to(device) model = Vits(config, speaker_manager=speaker_manager).to(device) - output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids}) + output_dict = model.forward( + input_dummy, + input_lengths, + spec, + spec_lengths, + waveform, + aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids}, + ) self._check_forward_outputs(config, output_dict, speaker_encoder_config) def test_inference(self): @@ -211,4 +237,4 @@ class TestVits(unittest.TestCase): speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device) lang_ids = torch.randint(0, num_langs, (1,)).long().to(device) model = Vits(config).to(device) - _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids}) \ No newline at end of file + _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})