mirror of https://github.com/coqui-ai/TTS.git

Add custom asserts to tests

This commit is contained in:
parent 7129b04d46
commit 497332bd46
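In summary, the commit touches two files: it adds the assertHasAttr / assertHasNotAttr helpers to the tests package, and puts them to use in the VITS model tests while also applying black/isort-style formatting cleanups.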
tests/__init__.py
@@ -38,3 +38,14 @@ def run_cli(command):
 
 def get_test_data_config():
     return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
+
+
+def assertHasAttr(test_obj, obj, intendedAttr):
+    # from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead
+    testBool = hasattr(obj, intendedAttr)
+    test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
+
+
+def assertHasNotAttr(test_obj, obj, intendedAttr):
+    testBool = hasattr(obj, intendedAttr)
+    test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
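For orientation, here is a minimal, self-contained sketch of how the two new helpers are meant to be called (the Dummy class and TestDummy case are hypothetical, and the import assumes the tests package from this commit is on the path). Each helper takes the running TestCase as its first argument so it can delegate to assertTrue/assertFalse and report failures through the normal unittest machinery:

import unittest

from tests import assertHasAttr, assertHasNotAttr


class Dummy:
    def __init__(self):
        # stand-in attribute, mirroring the emb_g checks used below
        self.emb_g = "speaker embedding placeholder"


class TestDummy(unittest.TestCase):
    def test_attrs(self):
        obj = Dummy()
        assertHasAttr(self, obj, "emb_g")  # passes: Dummy defines emb_g
        assertHasNotAttr(self, obj, "emb_x")  # passes: emb_x was never set


if __name__ == "__main__":
    unittest.main()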
tests/tts_tests/test_vits.py
@@ -1,13 +1,14 @@
 import os
-import torch
 import unittest
-from TTS.config import load_config
-from TTS.tts.models.vits import Vits, VitsArgs
-from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.utils.speakers import SpeakerManager
-from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
-from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
+
+import torch
 
+from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
+from TTS.config import load_config
+from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
+from TTS.tts.configs.vits_config import VitsConfig
+from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.utils.speakers import SpeakerManager
 
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
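This hunk only regroups the imports isort-style: standard library first, then third-party (torch), then project-level imports, with blank lines between groups; no import is added or removed.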
@@ -18,21 +19,21 @@ use_cuda = torch.cuda.is_available()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
-
+#pylint: disable=no-self-use
 class TestVits(unittest.TestCase):
     def test_init_multispeaker(self):
         num_speakers = 10
         args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
         model = Vits(args)
-        assertHasAttr(self, model, 'emb_g')
+        assertHasAttr(self, model, "emb_g")
 
         args = VitsArgs(num_speakers=0, use_speaker_embedding=True)
         model = Vits(args)
-        assertHasNotAttr(self, model, 'emb_g')
+        assertHasNotAttr(self, model, "emb_g")
 
         args = VitsArgs(num_speakers=10, use_speaker_embedding=False)
         model = Vits(args)
-        assertHasNotAttr(self, model, 'emb_g')
+        assertHasNotAttr(self, model, "emb_g")
 
         args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
         model = Vits(args)
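Here emb_g is the speaker-embedding lookup table the assertions probe for. As a rough paraphrase of the construction logic under test (a sketch, not the actual Vits source; nn here is torch.nn):

# hypothetical paraphrase: the embedding table exists only when speaker embeddings are active
if args.use_speaker_embedding and args.num_speakers > 0:
    self.emb_g = nn.Embedding(args.num_speakers, embedded_speaker_dim)

The three cases cover: both conditions met (the attribute must exist), zero speakers, and embeddings disabled (the attribute must be absent).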
@@ -67,12 +68,12 @@ class TestVits(unittest.TestCase):
         aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
         args = VitsArgs()
         model = Vits(args)
-        aux_out= model.get_aux_input(aux_input)
+        aux_out = model.get_aux_input(aux_input)
 
         speaker_id = torch.randint(10, (1,))
         language_id = torch.randint(10, (1,))
         d_vector = torch.rand(1, 128)
         aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
         aux_out = model.get_aux_input(aux_input)
         self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
         self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
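The two shape checks pin down that get_aux_input returns the speaker and language ids with their shapes unchanged instead of squeezing or re-batching them.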
@@ -88,8 +89,8 @@ class TestVits(unittest.TestCase):
 
         ref_inp = torch.randn(1, spec_len, 513)
         ref_inp_len = torch.randint(1, spec_effective_len, (1,))
-        ref_spk_id = torch.randint(0, num_speakers, (1,))
-        tgt_spk_id = torch.randint(0, num_speakers, (1,))
+        ref_spk_id = torch.randint(1, num_speakers, (1,))
+        tgt_spk_id = torch.randint(1, num_speakers, (1,))
         o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
 
         self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
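Note that torch.randint(low, high, size) samples from the half-open interval [low, high), so raising the lower bound from 0 to 1 keeps the sampled reference and target speaker ids strictly positive in this voice-conversion test.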
@@ -110,7 +111,9 @@ class TestVits(unittest.TestCase):
         return input_dummy, input_lengths, spec, spec_lengths, waveform
 
     def _check_forward_outputs(self, config, output_dict, encoder_config=None):
-        self.assertEqual(output_dict['model_outputs'].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"])
+        self.assertEqual(
+            output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
+        )
         self.assertEqual(output_dict["alignments"].shape, (8, 128, 30))
         self.assertEqual(output_dict["alignments"].max(), 1)
         self.assertEqual(output_dict["alignments"].min(), 0)
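The first assertion in this hunk (re-wrapped to fit the line-length limit) encodes the decoder's upsampling arithmetic: model_outputs must hold spec_segment_size * hop_length samples per item; for example, a hypothetical spec_segment_size of 10 at hop_length 256 would give a 2560-sample segment.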
@@ -120,13 +123,15 @@ class TestVits(unittest.TestCase):
         self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30))
         self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30))
         self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30))
-        self.assertEqual(output_dict['waveform_seg'].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"])
+        self.assertEqual(
+            output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
+        )
         if encoder_config:
-            self.assertEqual(output_dict['gt_spk_emb'].shape, (8, encoder_config.model_params["proj_dim"]))
-            self.assertEqual(output_dict['syn_spk_emb'].shape, (8, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
         else:
-            self.assertEqual(output_dict['gt_spk_emb'], None)
-            self.assertEqual(output_dict['syn_spk_emb'], None)
+            self.assertEqual(output_dict["gt_spk_emb"], None)
+            self.assertEqual(output_dict["syn_spk_emb"], None)
 
     def test_forward(self):
         num_speakers = 0
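The quote normalization and call re-wrapping in this hunk and the ones below are pure formatting (black-style double quotes, one argument per line); none of the assertion or test logic changes.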
@@ -147,7 +152,9 @@ class TestVits(unittest.TestCase):
         speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
 
         model = Vits(config).to(device)
-        output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids})
+        output_dict = model.forward(
+            input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
+        )
         self._check_forward_outputs(config, output_dict)
 
     def test_multilingual_forward(self):
@@ -162,7 +169,14 @@ class TestVits(unittest.TestCase):
         lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
 
         model = Vits(config).to(device)
-        output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids})
+        output_dict = model.forward(
+            input_dummy,
+            input_lengths,
+            spec,
+            spec_lengths,
+            waveform,
+            aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
+        )
         self._check_forward_outputs(config, output_dict)
 
     def test_secl_forward(self):
@@ -175,7 +189,12 @@ class TestVits(unittest.TestCase):
         speaker_manager = SpeakerManager()
         speaker_manager.speaker_encoder = speaker_encoder
 
-        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10, use_speaker_encoder_as_loss=True)
+        args = VitsArgs(
+            language_ids_file=LANG_FILE,
+            use_language_embedding=True,
+            spec_segment_size=10,
+            use_speaker_encoder_as_loss=True,
+        )
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
         config.audio.sample_rate = 16000
 
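Setting use_speaker_encoder_as_loss=True wires the speaker encoder in as a consistency loss (the SECL of the test name), which is why this test passes speaker_encoder_config into _check_forward_outputs and expects gt_spk_emb / syn_spk_emb outputs of size proj_dim rather than None.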
@@ -184,7 +203,14 @@ class TestVits(unittest.TestCase):
         lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
 
         model = Vits(config, speaker_manager=speaker_manager).to(device)
-        output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids})
+        output_dict = model.forward(
+            input_dummy,
+            input_lengths,
+            spec,
+            spec_lengths,
+            waveform,
+            aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
+        )
         self._check_forward_outputs(config, output_dict, speaker_encoder_config)
 
     def test_inference(self):
@@ -211,4 +237,4 @@ class TestVits(unittest.TestCase):
         speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
         model = Vits(config).to(device)
-        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
+        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})