coqui-tts/tests/loader_tests.py


import os
import unittest
import shutil
import numpy as np
from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config
from TTS.utils.audio import AudioProcessor
from TTS.datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory
from TTS.datasets.preprocess import ljspeech, tts_cache
file_path = os.path.dirname(os.path.realpath(__file__))
OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)
c = load_config(os.path.join(file_path, 'test_config.json'))
ok_ljspeech = os.path.exists(c.data_path)
DATA_EXIST = True
CACHE_EXIST = True
if not os.path.exists(c.data_path_cache):
    CACHE_EXIST = False

if not os.path.exists(c.data_path):
    DATA_EXIST = False
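
# Note: the config loaded above is expected to expose at least the fields
# referenced throughout this module: data_path, data_path_cache, r,
# text_cleaner, min_seq_len, batch_size, num_loader_workers and the audio
# dict (with 'num_mels'). A minimal illustrative sketch of such a
# test_config.json (paths and values below are placeholders, not the repo's
# actual settings):
#
#   {
#       "data_path": "/path/to/LJSpeech-1.1/",
#       "data_path_cache": "/path/to/LJSpeech-1.1/tts_cache/",
#       "r": 5,
#       "text_cleaner": "english_cleaners",
#       "min_seq_len": 0,
#       "batch_size": 2,
#       "num_loader_workers": 0,
#       "audio": {"num_mels": 80, "...": "..."}
#   }
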

class TestTTSDataset(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestTTSDataset, self).__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _create_dataloader(self, batch_size, r, bgs):
        dataset = TTSDataset.MyDataset(
            c.data_path,
            'metadata.csv',
            r,
            c.text_cleaner,
            preprocessor=ljspeech,
            ap=self.ap,
            batch_group_size=bgs,
            min_seq_len=c.min_seq_len)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers)
        return dataloader, dataset

    def test_loader(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                neg_values = text_input[text_input < 0]
                check_count = len(neg_values)
                assert check_count == 0, \
                    " !! Negative values in text_input: {}".format(check_count)
                # TODO: more assertions here
                assert linear_input.shape[0] == c.batch_size
                assert linear_input.shape[2] == self.ap.num_freq
                assert mel_input.shape[0] == c.batch_size
                assert mel_input.shape[2] == c.audio['num_mels']
                # check normalization ranges
                if self.ap.symmetric_norm:
                    assert mel_input.max() <= self.ap.max_norm
                    assert mel_input.min() >= -self.ap.max_norm
                    assert mel_input.min() < 0
                else:
                    assert mel_input.max() <= self.ap.max_norm
                    assert mel_input.min() >= 0

    def test_batch_group_shuffle(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 16)
            last_length = 0
            frames = dataset.items

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                avg_length = mel_lengths.numpy().mean()
                assert avg_length >= last_length
            dataloader.dataset.sort_items()
            assert frames[0] != dataloader.dataset.items[0]

    def test_padding_and_spec(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(1, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                # check mel_spec consistency
                wav = self.ap.load_wav(item_idx[0])
                mel = self.ap.melspectrogram(wav)
                mel_dl = mel_input[0].cpu().numpy()
                assert (abs(mel.T).astype("float32") -
                        abs(mel_dl[:-1])).sum() == 0

                # check mel-spec correctness
                mel_spec = mel_input[0].cpu().numpy()
                wav = self.ap.inv_mel_spectrogram(mel_spec.T)
                self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav')
                shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav')

                # check linear-spec
                linear_spec = linear_input[0].cpu().numpy()
                wav = self.ap.inv_spectrogram(linear_spec.T)
                self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
                shutil.copy(item_idx[0],
                            OUTPATH + '/linear_target_dataloader.wav')

                # check that the last time step is zero padded
                assert linear_input[0, -1].sum() == 0
                assert linear_input[0, -2].sum() != 0
                assert mel_input[0, -1].sum() == 0
                assert mel_input[0, -2].sum() != 0
                assert stop_target[0, -1] == 1
                assert stop_target[0, -2] == 0
                assert stop_target.sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[0] == linear_input[0].shape[0]
                assert mel_lengths[0] == mel_input[0].shape[0]

            # Test for batch size 2
            dataloader, dataset = self._create_dataloader(2, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                if mel_lengths[0] > mel_lengths[1]:
                    idx = 0
                else:
                    idx = 1

                # check the first item in the batch
                assert linear_input[idx, -1].sum() == 0
                assert linear_input[idx, -2].sum() != 0, linear_input
                assert mel_input[idx, -1].sum() == 0
                assert mel_input[idx, -2].sum() != 0, mel_input
                assert stop_target[idx, -1] == 1
                assert stop_target[idx, -2] == 0
                assert stop_target[idx].sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[idx] == mel_input[idx].shape[0]
                assert mel_lengths[idx] == linear_input[idx].shape[0]

                # check the second item in the batch
                assert linear_input[1 - idx, -1].sum() == 0
                assert mel_input[1 - idx, -1].sum() == 0
                assert stop_target[1 - idx, -1] == 1
                assert len(mel_lengths.shape) == 1

                # check batch conditions
                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0


class TestTTSDatasetCached(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestTTSDatasetCached, self).__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
        self.ap = AudioProcessor(**self.c.audio)

    def _create_dataloader(self, batch_size, r, bgs):
        dataset = TTSDatasetCached.MyDataset(
            c.data_path_cache,
            'tts_metadata.csv',
            r,
            c.text_cleaner,
            preprocessor=tts_cache,
            ap=self.ap,
            batch_group_size=bgs,
            min_seq_len=c.min_seq_len)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers)
        return dataloader, dataset

    def test_loader(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                neg_values = text_input[text_input < 0]
                check_count = len(neg_values)
                assert check_count == 0, \
                    " !! Negative values in text_input: {}".format(check_count)
                # TODO: more assertions here
                assert mel_input.shape[0] == c.batch_size
                assert mel_input.shape[2] == c.audio['num_mels']
                # check normalization ranges
                if self.ap.symmetric_norm:
                    assert mel_input.max() <= self.ap.max_norm
                    assert mel_input.min() >= -self.ap.max_norm
                    assert mel_input.min() < 0
                else:
                    assert mel_input.max() <= self.ap.max_norm
                    assert mel_input.min() >= 0

    def test_batch_group_shuffle(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(2, c.r, 16)
            frames = dataset.items

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                neg_values = text_input[text_input < 0]
                check_count = len(neg_values)
                assert check_count == 0, \
                    " !! Negative values in text_input: {}".format(check_count)
                # TODO: more assertions here
                assert mel_input.shape[0] == c.batch_size
                assert mel_input.shape[2] == c.audio['num_mels']
            dataloader.dataset.sort_items()
            assert frames[0] != dataloader.dataset.items[0]

    def test_padding_and_spec(self):
        if ok_ljspeech:
            dataloader, dataset = self._create_dataloader(1, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                # check mel_spec consistency
                if item_idx[0].split('.')[-1] == 'npy':
                    wav = np.load(item_idx[0])
                else:
                    wav = self.ap.load_wav(item_idx[0])
                mel = self.ap.melspectrogram(wav)
                mel_dl = mel_input[0].cpu().numpy()
                assert (abs(mel.T).astype("float32") -
                        abs(mel_dl[:-1])).sum() == 0, (
                            abs(mel.T).astype("float32") -
                            abs(mel_dl[:-1])).sum()

                # check mel-spec correctness
                mel_spec = mel_input[0].cpu().numpy()
                wav = self.ap.inv_mel_spectrogram(mel_spec.T)
                self.ap.save_wav(wav,
                                 OUTPATH + '/mel_inv_dataloader_cache.wav')
                shutil.copy(item_idx[0],
                            OUTPATH + '/mel_target_dataloader_cache.wav')

                # check that the last time step is zero padded
                assert mel_input[0, -1].sum() == 0
                assert mel_input[0, -2].sum() != 0
                assert stop_target[0, -1] == 1
                assert stop_target[0, -2] == 0
                assert stop_target.sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[0] == mel_input[0].shape[0]

            # Test for batch size 2
            dataloader, dataset = self._create_dataloader(2, 1, 0)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]

                if mel_lengths[0] > mel_lengths[1]:
                    idx = 0
                else:
                    idx = 1

                # check the first item in the batch
                assert mel_input[idx, -1].sum() == 0
                assert mel_input[idx, -2].sum() != 0, mel_input
                assert stop_target[idx, -1] == 1
                assert stop_target[idx, -2] == 0
                assert stop_target[idx].sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[idx] == mel_input[idx].shape[0]

                # check the second item in the batch
                assert mel_input[1 - idx, -1].sum() == 0
                assert stop_target[1 - idx, -1] == 1
                assert len(mel_lengths.shape) == 1

                # check batch conditions
                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0


# class TestTTSDatasetMemory(unittest.TestCase):
#     def __init__(self, *args, **kwargs):
#         super(TestTTSDatasetMemory, self).__init__(*args, **kwargs)
#         self.max_loader_iter = 4
#         self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
#         self.ap = AudioProcessor(**c.audio)

#     def test_loader(self):
#         if ok_ljspeech:
#             dataset = TTSDatasetMemory.MyDataset(
#                 c.data_path_cache,
#                 'tts_metadata.csv',
#                 c.r,
#                 c.text_cleaner,
#                 preprocessor=tts_cache,
#                 ap=self.ap,
#                 min_seq_len=c.min_seq_len)
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=2,
#                 shuffle=True,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=True,
#                 num_workers=c.num_loader_workers)

#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

#                 neg_values = text_input[text_input < 0]
#                 check_count = len(neg_values)
#                 assert check_count == 0, \
#                     " !! Negative values in text_input: {}".format(check_count)
#                 # check mel-spec shape
#                 assert mel_input.shape[0] == c.batch_size
#                 assert mel_input.shape[2] == c.audio['num_mels']
#                 assert mel_input.max() <= self.ap.max_norm
#                 # check data range
#                 if self.ap.symmetric_norm:
#                     assert mel_input.max() <= self.ap.max_norm
#                     assert mel_input.min() >= -self.ap.max_norm
#                     assert mel_input.min() < 0
#                 else:
#                     assert mel_input.max() <= self.ap.max_norm
#                     assert mel_input.min() >= 0

#     def test_batch_group_shuffle(self):
#         if ok_ljspeech:
#             dataset = TTSDatasetMemory.MyDataset(
#                 c.data_path_cache,
#                 'tts_metadata.csv',
#                 c.r,
#                 c.text_cleaner,
#                 preprocessor=ljspeech,
#                 ap=self.ap,
#                 batch_group_size=16,
#                 min_seq_len=c.min_seq_len)
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=2,
#                 shuffle=True,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=True,
#                 num_workers=c.num_loader_workers)
#             frames = dataset.items

#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

#                 neg_values = text_input[text_input < 0]
#                 check_count = len(neg_values)
#                 assert check_count == 0, \
#                     " !! Negative values in text_input: {}".format(check_count)
#                 assert mel_input.shape[0] == c.batch_size
#                 assert mel_input.shape[2] == c.audio['num_mels']
#             dataloader.dataset.sort_items()
#             assert frames[0] != dataloader.dataset.items[0]

#     def test_padding_and_spec(self):
#         if ok_ljspeech:
#             dataset = TTSDatasetMemory.MyDataset(
#                 c.data_path_cache,
#                 'tts_meta_data.csv',
#                 1,
#                 c.text_cleaner,
#                 preprocessor=ljspeech,
#                 ap=self.ap,
#                 min_seq_len=c.min_seq_len)

#             # Test for batch size 1
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=1,
#                 shuffle=False,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=True,
#                 num_workers=c.num_loader_workers)

#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

#                 # check mel_spec consistency
#                 if item_idx[0].split('.')[-1] == 'npy':
#                     wav = np.load(item_idx[0])
#                 else:
#                     wav = self.ap.load_wav(item_idx[0])
#                 mel = self.ap.melspectrogram(wav)
#                 mel_dl = mel_input[0].cpu().numpy()
#                 assert (abs(mel.T).astype("float32") -
#                         abs(mel_dl[:-1])).sum() == 0

#                 # check mel-spec correctness
#                 mel_spec = mel_input[0].cpu().numpy()
#                 wav = self.ap.inv_mel_spectrogram(mel_spec.T)
#                 self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav')
#                 shutil.copy(item_idx[0],
#                             OUTPATH + '/mel_target_dataloader_memo.wav')

#                 # check that the last time step is zero padded
#                 assert mel_input[0, -1].sum() == 0
#                 assert mel_input[0, -2].sum() != 0
#                 assert stop_target[0, -1] == 1
#                 assert stop_target[0, -2] == 0
#                 assert stop_target.sum() == 1
#                 assert len(mel_lengths.shape) == 1
#                 assert mel_lengths[0] == mel_input[0].shape[0]

#             # Test for batch size 2
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=2,
#                 shuffle=False,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=False,
#                 num_workers=c.num_loader_workers)

#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

#                 if mel_lengths[0] > mel_lengths[1]:
#                     idx = 0
#                 else:
#                     idx = 1

#                 # check the first item in the batch
#                 assert mel_input[idx, -1].sum() == 0
#                 assert mel_input[idx, -2].sum() != 0, mel_input
#                 assert stop_target[idx, -1] == 1
#                 assert stop_target[idx, -2] == 0
#                 assert stop_target[idx].sum() == 1
#                 assert len(mel_lengths.shape) == 1
#                 assert mel_lengths[idx] == mel_input[idx].shape[0]

#                 # check the second item in the batch
#                 assert mel_input[1 - idx, -1].sum() == 0
#                 assert stop_target[1 - idx, -1] == 1
#                 assert len(mel_lengths.shape) == 1

#                 # check batch conditions
#                 assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
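

# The original module presumably relies on an external test runner (e.g.
# unittest/pytest discovery); the guard below is a convenience addition so
# the file can also be executed directly. Only the standard-library
# unittest.main() entry point is assumed here, nothing project-specific.
if __name__ == '__main__':
    unittest.main()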