mirror of https://github.com/coqui-ai/TTS.git
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import unittest
|
|
|
|
import torch
|
|
|
|
from TTS.tts.configs.matcha_tts import MatchaTTSConfig
|
|
from TTS.tts.models.matcha_tts import MatchaTTS
|
|
|
|
# Seed torch's global RNG before anything random is created, so the dummy
# tensors (and any randomly initialised weights) are repeatable across runs.
torch.manual_seed(1)

# Use the first CUDA device when one is available; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Default configuration; the test class reads `audio["num_mels"]` from it to
# size the dummy mel tensor.
c = MatchaTTSConfig()
|
|
|
|
|
|
class TestMatchTTS(unittest.TestCase):
    """Smoke test that runs a training-mode forward pass of MatchaTTS.

    The test makes no assertion on the model output; it only checks that
    `forward` completes without raising for the tested batch sizes.
    """

    @staticmethod
    def _create_inputs(batch_size=8):
        """Create one batch of random dummy tensors on the module-level device.

        Args:
            batch_size: Number of items in the batch.

        Returns:
            A 5-tuple ``(input_dummy, input_lengths, mel_spec, mel_lengths,
            speaker_ids)``:

            * ``input_dummy``: int64, shape ``(batch_size, 128)``, values in [0, 24).
            * ``input_lengths``: int64, shape ``(batch_size,)``, values in
              [100, 128]; the last entry is forced to 128.
            * ``mel_spec``: float, shape ``(batch_size, 30, num_mels)``, values
              in [0, 1), with ``num_mels`` read from the default config.
            * ``mel_lengths``: int64, shape ``(batch_size,)``, values in [20, 30).
            * ``speaker_ids``: int64, shape ``(batch_size,)``, values in [0, 5).
        """
        per_item = (batch_size,)
        n_mels = c.audio["num_mels"]

        # NOTE: the tensors are drawn in this exact order on purpose; each draw
        # consumes the global RNG stream, so reordering would change the values.
        dummy_input = torch.randint(0, 24, (batch_size, 128)).long().to(device)

        dummy_input_lengths = torch.randint(100, 129, per_item).long().to(device)
        # Pin the last item to 128, the width of the padded input tensor.
        dummy_input_lengths[-1] = 128

        mels = torch.rand(batch_size, 30, n_mels).to(device)
        mel_lens = torch.randint(20, 30, per_item).long().to(device)
        spk_ids = torch.randint(0, 5, per_item).long().to(device)

        return dummy_input, dummy_input_lengths, mels, mel_lens, spk_ids

    def _test_forward(self, batch_size):
        """Build a fresh model and run one forward pass in training mode.

        Args:
            batch_size: Number of items in the dummy batch fed to the model.
        """
        # Inputs are created before the model so the RNG stream is consumed in
        # the same order as before. The speaker ids (last element) are unused.
        batch = self._create_inputs(batch_size)
        dummy_input, dummy_input_lengths, mels, mel_lens = batch[:4]

        model_config = MatchaTTSConfig(num_chars=32)
        matcha = MatchaTTS(model_config).to(device)

        matcha.train()

        matcha.forward(dummy_input, dummy_input_lengths, mels, mel_lens)

    def test_forward(self):
        """Run the forward smoke test for batch sizes 1 and 3, in that order."""
        for batch_size in (1, 3):
            self._test_forward(batch_size)
|