mirror of https://github.com/coqui-ai/TTS.git
fix unit tests
This commit is contained in:
parent
c90037c2e9
commit
7a9a27282a
|
@ -4,8 +4,8 @@ import torch as T
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
|
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
|
||||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||||
|
# from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||||
file_path = get_tests_input_path()
|
file_path = get_tests_input_path()
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ class SpeakerEncoderTests(unittest.TestCase):
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
dummy_input = T.rand(4, 20, 80) # B x T x D
|
dummy_input = T.rand(4, 20, 80) # B x T x D
|
||||||
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
||||||
model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
|
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
|
||||||
# computing d vectors
|
# computing d vectors
|
||||||
output = model.forward(dummy_input)
|
output = model.forward(dummy_input)
|
||||||
assert output.shape[0] == 4
|
assert output.shape[0] == 4
|
||||||
|
@ -96,17 +96,3 @@ class AngleProtoLossTests(unittest.TestCase):
|
||||||
loss = AngleProtoLoss()
|
loss = AngleProtoLoss()
|
||||||
output = loss.forward(dummy_input)
|
output = loss.forward(dummy_input)
|
||||||
assert output.item() < 0.005
|
assert output.item() < 0.005
|
||||||
|
|
||||||
|
|
||||||
# class LoaderTest(unittest.TestCase):
|
|
||||||
# def test_output(self):
|
|
||||||
# items = libri_tts("/home/erogol/Data/Libri-TTS/train-clean-360/")
|
|
||||||
# ap = AudioProcessor(**c['audio'])
|
|
||||||
# dataset = MyDataset(ap, items, 1.6, 64, 10)
|
|
||||||
# loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
|
|
||||||
# count = 0
|
|
||||||
# for mel, spk in loader:
|
|
||||||
# print(mel.shape)
|
|
||||||
# if count == 4:
|
|
||||||
# break
|
|
||||||
# count += 1
|
|
||||||
|
|
Loading…
Reference in New Issue