diff --git a/tf/tests/__init__.py b/tf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tf/tests/test_layers_tf.py b/tf/tests/test_layers_tf.py new file mode 100644 index 00000000..e69de29b diff --git a/tf/tests/test_tacotron2_tf_model.py b/tf/tests/test_tacotron2_tf_model.py new file mode 100644 index 00000000..dc904ea2 --- /dev/null +++ b/tf/tests/test_tacotron2_tf_model.py @@ -0,0 +1,135 @@ +import os +import torch +import unittest +import numpy as np +import tensorflow as tf +tf.get_logger().setLevel('INFO') + +from TTS.utils.io import load_config +from TTS.tf.models.tacotron2 import Tacotron2 +from TTS.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model +from TTS.utils.synthesis import run_model_tflite, text_to_seqvec + +#pylint: disable=unused-variable + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +file_path = os.path.dirname(os.path.realpath(__file__)).replace('/tf/','/') +c = load_config(os.path.join(file_path, 'test_config.json')) + + +class TacotronTFTrainTest(unittest.TestCase): + + @staticmethod + def generate_dummy_inputs(): + chars_seq = torch.randint(0, 24, (8, 128)).long().to(device) + chars_seq_lengths = torch.randint(100, 128, (8, )).long().to(device) + chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0] + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy()) + chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy()) + mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy()) + return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\ + stop_targets, speaker_ids + + def test_train_step(self): + ''' test forward pass ''' + chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\ + stop_targets, speaker_ids = self.generate_dummy_inputs() + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(chars_seq.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5) + # training pass + output = model(chars_seq, chars_seq_lengths, mel_spec, training=True) + + # check model output shapes + assert np.all(output[0].shape == mel_spec.shape) + assert np.all(output[1].shape == mel_spec.shape) + assert output[2].shape[2] == chars_seq.shape[1] + assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) + assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) + + # inference pass + output = model(chars_seq, training=False) + + def test_forward_attention(self,): + chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\ + stop_targets, speaker_ids = self.generate_dummy_inputs() + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(chars_seq.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, forward_attn=True) + # training pass + output = model(chars_seq, chars_seq_lengths, mel_spec, training=True) + + # check model output shapes + assert np.all(output[0].shape == mel_spec.shape) + assert np.all(output[1].shape == mel_spec.shape) + assert output[2].shape[2] == chars_seq.shape[1] + assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) + assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) + + # inference pass + output = model(chars_seq, training=False) + + def test_tflite_conversion(self, ): + model = Tacotron2(num_chars=24, + num_speakers=0, + r=3, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type='original', + attn_win=False, + attn_norm='sigmoid', + prenet_type='original', + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=0, + separate_stopnet=True, + bidirectional_decoder=False, + enable_tflite=True) + model.build_inference() + convert_tacotron2_to_tflite(model, output_path='test_tacotron2.tflite', experimental_converter=True) + # init tflite model + tflite_model = load_tflite_model('test_tacotron2.tflite') + # fake input + inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) + # run inference + # get input and output details + input_details = tflite_model.get_input_details() + output_details = tflite_model.get_output_details() + # reshape input tensor for the new input shape + tflite_model.resize_tensor_input(input_details[0]['index'], inputs.shape) + tflite_model.allocate_tensors() + detail = input_details[0] + input_shape = detail['shape'] + tflite_model.set_tensor(detail['index'], inputs) + # run the tflite_model + tflite_model.invoke() + # collect outputs + decoder_output = tflite_model.get_tensor(output_details[0]['index']) + postnet_output = tflite_model.get_tensor(output_details[1]['index']) + # remove tflite binary + os.remove('test_tacotron2.tflite') +