bugfix on tacotron unit test

This commit is contained in:
Edresson 2021-05-05 06:38:01 -03:00
parent e3f56b613b
commit d78f27ea41
1 changed files with 2 additions and 0 deletions

View File

@ -37,6 +37,7 @@ class TacotronTrainTest(unittest.TestCase):
mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[-1] = mel_spec.size(1)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
@ -96,6 +97,7 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths[-1] = mel_spec.size(1)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_embeddings = torch.rand(8, 55).to(device)