diff --git a/tests/test_speaker_encoder.py b/tests/test_speaker_encoder.py index 8939ccf6..3b45d2e2 100644 --- a/tests/test_speaker_encoder.py +++ b/tests/test_speaker_encoder.py @@ -4,8 +4,8 @@ import torch as T from tests import get_tests_input_path from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss -from TTS.speaker_encoder.model import SpeakerEncoder - +from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder +# from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder file_path = get_tests_input_path() @@ -14,7 +14,7 @@ class SpeakerEncoderTests(unittest.TestCase): def test_in_out(self): dummy_input = T.rand(4, 20, 80) # B x T x D dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)] - model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3) + model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3) # computing d vectors output = model.forward(dummy_input) assert output.shape[0] == 4 @@ -96,17 +96,3 @@ class AngleProtoLossTests(unittest.TestCase): loss = AngleProtoLoss() output = loss.forward(dummy_input) assert output.item() < 0.005 - - -# class LoaderTest(unittest.TestCase): -# def test_output(self): -# items = libri_tts("/home/erogol/Data/Libri-TTS/train-clean-360/") -# ap = AudioProcessor(**c['audio']) -# dataset = MyDataset(ap, items, 1.6, 64, 10) -# loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn) -# count = 0 -# for mel, spk in loader: -# print(mel.shape) -# if count == 4: -# break -# count += 1