From 88053706450b3e964a972a209d3bc66270a9f7b6 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 18 May 2020 11:34:13 +0200 Subject: [PATCH] add tf tacotron2 test and edit test utils imports after utils refactoring --- tests/test_demo_server.py | 3 +- tests/test_loader.py | 2 +- tests/test_tacotron2_model.py | 2 +- tests/test_tacotron2_tf_model.py | 59 ++++++++++++++++++++++++++++++++ tests/test_tacotron_model.py | 2 +- tests/test_text_processing.py | 4 +-- 6 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 tests/test_tacotron2_tf_model.py diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index a0837686..11d16a45 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -6,7 +6,8 @@ import torch as T from TTS.server.synthesizer import Synthesizer from TTS.tests import get_tests_input_path, get_tests_output_path from TTS.utils.text.symbols import make_symbols, phonemes, symbols -from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model +from TTS.utils.generic_utils import setup_model +from TTS.utils.io import load_config, save_checkpoint class DemoServerTest(unittest.TestCase): diff --git a/tests/test_loader.py b/tests/test_loader.py index 447c7b38..9edd233f 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -5,7 +5,7 @@ import torch import numpy as np from torch.utils.data import DataLoader -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor from TTS.datasets import TTSDataset from TTS.datasets.preprocess import ljspeech diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index aa2869eb..eb91b3cc 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -6,7 +6,7 @@ import numpy as np from torch import optim from torch import nn -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config from TTS.layers.losses import MSELossMasked from TTS.models.tacotron2 import Tacotron2 diff --git a/tests/test_tacotron2_tf_model.py b/tests/test_tacotron2_tf_model.py new file mode 100644 index 00000000..27398748 --- /dev/null +++ b/tests/test_tacotron2_tf_model.py @@ -0,0 +1,59 @@ +import os +import copy +import torch +import unittest +import numpy as np +import tensorflow as tf + +from torch import optim +from torch import nn +from TTS.utils.io import load_config +from TTS.layers.losses import MSELossMasked +from TTS.tf.models.tacotron2 import Tacotron2 + +#pylint: disable=unused-variable + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +file_path = os.path.dirname(os.path.realpath(__file__)) +c = load_config(os.path.join(file_path, 'test_config.json')) + + +class TacotronTFTrainTest(unittest.TestCase): + def test_train_step(self): + ''' test forward pass ''' + input = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.sort(input_lengths, descending=True)[0] + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + input = tf.convert_to_tensor(input.cpu().numpy()) + input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy()) + mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy()) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5) + # training pass + output = model(input, input_lengths, mel_spec, training=True) + + # check model output shapes + assert np.all(output[0].shape == mel_spec.shape) + assert np.all(output[1].shape == mel_spec.shape) + assert output[2].shape[2] == input.shape[1] + assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) + assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) + + # inference pass + output = model(input, training=False) diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index ac6712b0..7053a580 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -5,7 +5,7 @@ import unittest from torch import optim from torch import nn -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config from TTS.layers.losses import L1LossMasked from TTS.models.tacotron import Tacotron diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 6c0c7058..93edabe7 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -5,7 +5,7 @@ import os import unittest from TTS.utils.text import * from TTS.tests import get_tests_path -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config TESTS_PATH = get_tests_path() conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) @@ -92,4 +92,4 @@ def test_text2phone(): gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!" lang = "en-us" ph = text2phone(text, lang) - assert gt == ph, f"\n{phonemes} \n vs \n{gt}" \ No newline at end of file + assert gt == ph, f"\n{phonemes} \n vs \n{gt}"