add tf tacotron2 test and edit test utils imports after utils

refactoring
This commit is contained in:
erogol 2020-05-18 11:34:13 +02:00
parent 67397be1c0
commit 8805370645
6 changed files with 66 additions and 6 deletions

View File

@ -6,7 +6,8 @@ import torch as T
from TTS.server.synthesizer import Synthesizer from TTS.server.synthesizer import Synthesizer
from TTS.tests import get_tests_input_path, get_tests_output_path from TTS.tests import get_tests_input_path, get_tests_output_path
from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config, save_checkpoint
class DemoServerTest(unittest.TestCase): class DemoServerTest(unittest.TestCase):

View File

@ -5,7 +5,7 @@ import torch
import numpy as np import numpy as np
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config from TTS.utils.io import load_config
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.datasets import TTSDataset from TTS.datasets import TTSDataset
from TTS.datasets.preprocess import ljspeech from TTS.datasets.preprocess import ljspeech

View File

@ -6,7 +6,7 @@ import numpy as np
from torch import optim from torch import optim
from torch import nn from torch import nn
from TTS.utils.generic_utils import load_config from TTS.utils.io import load_config
from TTS.layers.losses import MSELossMasked from TTS.layers.losses import MSELossMasked
from TTS.models.tacotron2 import Tacotron2 from TTS.models.tacotron2 import Tacotron2

View File

@ -0,0 +1,59 @@
import os
import copy
import torch
import unittest
import numpy as np
import tensorflow as tf
from torch import optim
from torch import nn
from TTS.utils.io import load_config
from TTS.layers.losses import MSELossMasked
from TTS.tf.models.tacotron2 import Tacotron2
#pylint: disable=unused-variable
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTFTrainTest(unittest.TestCase):
def test_train_step(self):
''' test forward pass '''
input = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
input = tf.convert_to_tensor(input.cpu().numpy())
input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy())
mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)
# training pass
output = model(input, input_lengths, mel_spec, training=True)
# check model output shapes
assert np.all(output[0].shape == mel_spec.shape)
assert np.all(output[1].shape == mel_spec.shape)
assert output[2].shape[2] == input.shape[1]
assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r)
assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
# inference pass
output = model(input, training=False)

View File

@ -5,7 +5,7 @@ import unittest
from torch import optim from torch import optim
from torch import nn from torch import nn
from TTS.utils.generic_utils import load_config from TTS.utils.io import load_config
from TTS.layers.losses import L1LossMasked from TTS.layers.losses import L1LossMasked
from TTS.models.tacotron import Tacotron from TTS.models.tacotron import Tacotron

View File

@ -5,7 +5,7 @@ import os
import unittest import unittest
from TTS.utils.text import * from TTS.utils.text import *
from TTS.tests import get_tests_path from TTS.tests import get_tests_path
from TTS.utils.generic_utils import load_config from TTS.utils.io import load_config
TESTS_PATH = get_tests_path() TESTS_PATH = get_tests_path()
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))