mirror of https://github.com/coqui-ai/TTS.git
tests updates
This commit is contained in:
parent
caae1af4f6
commit
dce1715e0f
|
@ -61,7 +61,7 @@ class MyDataset(Dataset):
|
||||||
self.use_phonemes = use_phonemes
|
self.use_phonemes = use_phonemes
|
||||||
self.phoneme_cache_path = phoneme_cache_path
|
self.phoneme_cache_path = phoneme_cache_path
|
||||||
self.phoneme_language = phoneme_language
|
self.phoneme_language = phoneme_language
|
||||||
if not os.path.isdir(phoneme_cache_path):
|
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||||
os.makedirs(phoneme_cache_path)
|
os.makedirs(phoneme_cache_path)
|
||||||
print(" > DataLoader initialization")
|
print(" > DataLoader initialization")
|
||||||
print(" | > Data path: {}".format(root_path))
|
print(" | > Data path: {}".format(root_path))
|
||||||
|
|
|
@ -38,7 +38,7 @@ class CBHGTests(unittest.TestCase):
|
||||||
|
|
||||||
class DecoderTests(unittest.TestCase):
|
class DecoderTests(unittest.TestCase):
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Decoder(in_features=256, memory_dim=80, r=2)
|
layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False)
|
||||||
dummy_input = T.rand(4, 8, 256)
|
dummy_input = T.rand(4, 8, 256)
|
||||||
dummy_memory = T.rand(4, 2, 80)
|
dummy_memory = T.rand(4, 2, 80)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from utils.generic_utils import load_config
|
from utils.generic_utils import load_config
|
||||||
from utils.audio import AudioProcessor
|
from utils.audio import AudioProcessor
|
||||||
from datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory
|
from datasets import TTSDataset
|
||||||
from datasets.preprocess import ljspeech, tts_cache
|
from datasets.preprocess import ljspeech, tts_cache
|
||||||
|
|
||||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
@ -41,7 +41,9 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
preprocessor=ljspeech,
|
preprocessor=ljspeech,
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
batch_group_size=bgs,
|
batch_group_size=bgs,
|
||||||
min_seq_len=c.min_seq_len)
|
min_seq_len=c.min_seq_len,
|
||||||
|
max_seq_len=float("inf"),
|
||||||
|
use_phonemes=False)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
@ -190,366 +192,4 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
|
|
||||||
# check batch conditions
|
# check batch conditions
|
||||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
class TestTTSDatasetCached(unittest.TestCase):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(TestTTSDatasetCached, self).__init__(*args, **kwargs)
|
|
||||||
self.max_loader_iter = 4
|
|
||||||
self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
|
|
||||||
self.ap = AudioProcessor(**self.c.audio)
|
|
||||||
|
|
||||||
def _create_dataloader(self, batch_size, r, bgs):
|
|
||||||
|
|
||||||
dataset = TTSDataset.MyDataset(
|
|
||||||
c.data_path_cache,
|
|
||||||
'tts_metadata.csv',
|
|
||||||
r,
|
|
||||||
c.text_cleaner,
|
|
||||||
preprocessor=tts_cache,
|
|
||||||
ap=self.ap,
|
|
||||||
batch_group_size=bgs,
|
|
||||||
min_seq_len=c.min_seq_len,
|
|
||||||
max_seq_len=c.max_seq_len,
|
|
||||||
cached=True)
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=dataset.collate_fn,
|
|
||||||
drop_last=True,
|
|
||||||
num_workers=c.num_loader_workers)
|
|
||||||
return dataloader, dataset
|
|
||||||
|
|
||||||
def test_loader(self):
|
|
||||||
if ok_ljspeech:
|
|
||||||
dataloader, dataset = self._create_dataloader(2, c.r, 0)
|
|
||||||
for i, data in enumerate(dataloader):
|
|
||||||
if i == self.max_loader_iter:
|
|
||||||
break
|
|
||||||
text_input = data[0]
|
|
||||||
text_lengths = data[1]
|
|
||||||
linear_input = data[2]
|
|
||||||
mel_input = data[3]
|
|
||||||
mel_lengths = data[4]
|
|
||||||
stop_target = data[5]
|
|
||||||
item_idx = data[6]
|
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
|
||||||
check_count = len(neg_values)
|
|
||||||
assert check_count == 0, \
|
|
||||||
" !! Negative values in text_input: {}".format(check_count)
|
|
||||||
# TODO: more assertion here
|
|
||||||
assert mel_input.shape[0] == c.batch_size
|
|
||||||
assert mel_input.shape[2] == c.audio['num_mels']
|
|
||||||
|
|
||||||
if self.ap.symmetric_norm:
|
|
||||||
assert mel_input.max() <= self.ap.max_norm
|
|
||||||
assert mel_input.min() >= -self.ap.max_norm
|
|
||||||
assert mel_input.min() < 0
|
|
||||||
else:
|
|
||||||
assert mel_input.max() <= self.ap.max_norm
|
|
||||||
assert mel_input.min() >= 0
|
|
||||||
|
|
||||||
def test_batch_group_shuffle(self):
|
|
||||||
if ok_ljspeech:
|
|
||||||
dataloader, dataset = self._create_dataloader(2, c.r, 16)
|
|
||||||
frames = dataset.items
|
|
||||||
for i, data in enumerate(dataloader):
|
|
||||||
if i == self.max_loader_iter:
|
|
||||||
break
|
|
||||||
text_input = data[0]
|
|
||||||
text_lengths = data[1]
|
|
||||||
linear_input = data[2]
|
|
||||||
mel_input = data[3]
|
|
||||||
mel_lengths = data[4]
|
|
||||||
stop_target = data[5]
|
|
||||||
item_idx = data[6]
|
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
|
||||||
check_count = len(neg_values)
|
|
||||||
assert check_count == 0, \
|
|
||||||
" !! Negative values in text_input: {}".format(check_count)
|
|
||||||
# TODO: more assertion here
|
|
||||||
assert mel_input.shape[0] == c.batch_size
|
|
||||||
assert mel_input.shape[2] == c.audio['num_mels']
|
|
||||||
dataloader.dataset.sort_items()
|
|
||||||
assert frames[0] != dataloader.dataset.items[0]
|
|
||||||
|
|
||||||
def test_padding_and_spec(self):
|
|
||||||
if ok_ljspeech:
|
|
||||||
dataloader, dataset = self._create_dataloader(1, 1, 0)
|
|
||||||
for i, data in enumerate(dataloader):
|
|
||||||
if i == self.max_loader_iter:
|
|
||||||
break
|
|
||||||
text_input = data[0]
|
|
||||||
text_lengths = data[1]
|
|
||||||
linear_input = data[2]
|
|
||||||
mel_input = data[3]
|
|
||||||
mel_lengths = data[4]
|
|
||||||
stop_target = data[5]
|
|
||||||
item_idx = data[6]
|
|
||||||
|
|
||||||
# check mel_spec consistency
|
|
||||||
if item_idx[0].split('.')[-1] == 'npy':
|
|
||||||
wav = np.load(item_idx[0])
|
|
||||||
else:
|
|
||||||
wav = self.ap.load_wav(item_idx[0])
|
|
||||||
mel = self.ap.melspectrogram(wav)
|
|
||||||
mel_dl = mel_input[0].cpu().numpy()
|
|
||||||
assert (abs(mel.T).astype("float32") - abs(
|
|
||||||
mel_dl[:-1])).sum() == 0, (
|
|
||||||
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
|
|
||||||
|
|
||||||
# check mel-spec correctness
|
|
||||||
mel_spec = mel_input[-1].cpu().numpy()
|
|
||||||
wav = self.ap.inv_mel_spectrogram(mel_spec.T)
|
|
||||||
self.ap.save_wav(wav,
|
|
||||||
OUTPATH + '/mel_inv_dataloader_cache.wav')
|
|
||||||
shutil.copy(item_idx[-1], OUTPATH + '/mel_target_dataloader_cache.wav')
|
|
||||||
|
|
||||||
# check linear-spec
|
|
||||||
linear_spec = linear_input[-1].cpu().numpy()
|
|
||||||
wav = self.ap.inv_spectrogram(linear_spec.T)
|
|
||||||
self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader_cache.wav')
|
|
||||||
shutil.copy(item_idx[-1], OUTPATH + '/linear_target_dataloader_cache.wav')
|
|
||||||
|
|
||||||
# check the last time step to be zero padded
|
|
||||||
assert mel_input[0, -1].sum() == 0
|
|
||||||
assert mel_input[0, -2].sum() != 0
|
|
||||||
assert stop_target[0, -1] == 1
|
|
||||||
assert stop_target[0, -2] == 0
|
|
||||||
assert stop_target.sum() == 1
|
|
||||||
assert len(mel_lengths.shape) == 1
|
|
||||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
|
||||||
|
|
||||||
# Test for batch size 2
|
|
||||||
dataloader, dataset = self._create_dataloader(2, 1, 0)
|
|
||||||
for i, data in enumerate(dataloader):
|
|
||||||
if i == self.max_loader_iter:
|
|
||||||
break
|
|
||||||
text_input = data[0]
|
|
||||||
text_lengths = data[1]
|
|
||||||
linear_input = data[2]
|
|
||||||
mel_input = data[3]
|
|
||||||
mel_lengths = data[4]
|
|
||||||
stop_target = data[5]
|
|
||||||
item_idx = data[6]
|
|
||||||
|
|
||||||
if mel_lengths[0] > mel_lengths[1]:
|
|
||||||
idx = 0
|
|
||||||
else:
|
|
||||||
idx = 1
|
|
||||||
|
|
||||||
# check the first item in the batch
|
|
||||||
assert mel_input[idx, -1].sum() == 0
|
|
||||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
|
||||||
assert stop_target[idx, -1] == 1
|
|
||||||
assert stop_target[idx, -2] == 0
|
|
||||||
assert stop_target[idx].sum() == 1
|
|
||||||
assert len(mel_lengths.shape) == 1
|
|
||||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
|
||||||
|
|
||||||
# check the second itme in the batch
|
|
||||||
assert mel_input[1 - idx, -1].sum() == 0
|
|
||||||
assert stop_target[1 - idx, -1] == 1
|
|
||||||
assert len(mel_lengths.shape) == 1
|
|
||||||
|
|
||||||
# check batch conditions
|
|
||||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
|
||||||
|
|
||||||
|
|
||||||
# class TestTTSDatasetMemory(unittest.TestCase):
|
|
||||||
# def __init__(self, *args, **kwargs):
|
|
||||||
# super(TestTTSDatasetMemory, self).__init__(*args, **kwargs)
|
|
||||||
# self.max_loader_iter = 4
|
|
||||||
# self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
|
|
||||||
# self.ap = AudioProcessor(**c.audio)
|
|
||||||
|
|
||||||
# def test_loader(self):
|
|
||||||
# if ok_ljspeech:
|
|
||||||
# dataset = TTSDatasetMemory.MyDataset(
|
|
||||||
# c.data_path_cache,
|
|
||||||
# 'tts_metadata.csv',
|
|
||||||
# c.r,
|
|
||||||
# c.text_cleaner,
|
|
||||||
# preprocessor=tts_cache,
|
|
||||||
# ap=self.ap,
|
|
||||||
# min_seq_len=c.min_seq_len)
|
|
||||||
|
|
||||||
# dataloader = DataLoader(
|
|
||||||
# dataset,
|
|
||||||
# batch_size=2,
|
|
||||||
# shuffle=True,
|
|
||||||
# collate_fn=dataset.collate_fn,
|
|
||||||
# drop_last=True,
|
|
||||||
# num_workers=c.num_loader_workers)
|
|
||||||
|
|
||||||
# for i, data in enumerate(dataloader):
|
|
||||||
# if i == self.max_loader_iter:
|
|
||||||
# break
|
|
||||||
# text_input = data[0]
|
|
||||||
# text_lengths = data[1]
|
|
||||||
# linear_input = data[2]
|
|
||||||
# mel_input = data[3]
|
|
||||||
# mel_lengths = data[4]
|
|
||||||
# stop_target = data[5]
|
|
||||||
# item_idx = data[6]
|
|
||||||
|
|
||||||
# neg_values = text_input[text_input < 0]
|
|
||||||
# check_count = len(neg_values)
|
|
||||||
# assert check_count == 0, \
|
|
||||||
# " !! Negative values in text_input: {}".format(check_count)
|
|
||||||
# # check mel-spec shape
|
|
||||||
# assert mel_input.shape[0] == c.batch_size
|
|
||||||
# assert mel_input.shape[2] == c.audio['num_mels']
|
|
||||||
# assert mel_input.max() <= self.ap.max_norm
|
|
||||||
# # check data range
|
|
||||||
# if self.ap.symmetric_norm:
|
|
||||||
# assert mel_input.max() <= self.ap.max_norm
|
|
||||||
# assert mel_input.min() >= -self.ap.max_norm
|
|
||||||
# assert mel_input.min() < 0
|
|
||||||
# else:
|
|
||||||
# assert mel_input.max() <= self.ap.max_norm
|
|
||||||
# assert mel_input.min() >= 0
|
|
||||||
|
|
||||||
# def test_batch_group_shuffle(self):
|
|
||||||
# if ok_ljspeech:
|
|
||||||
# dataset = TTSDatasetMemory.MyDataset(
|
|
||||||
# c.data_path_cache,
|
|
||||||
# 'tts_metadata.csv',
|
|
||||||
# c.r,
|
|
||||||
# c.text_cleaner,
|
|
||||||
# preprocessor=ljspeech,
|
|
||||||
# ap=self.ap,
|
|
||||||
# batch_group_size=16,
|
|
||||||
# min_seq_len=c.min_seq_len)
|
|
||||||
|
|
||||||
# dataloader = DataLoader(
|
|
||||||
# dataset,
|
|
||||||
# batch_size=2,
|
|
||||||
# shuffle=True,
|
|
||||||
# collate_fn=dataset.collate_fn,
|
|
||||||
# drop_last=True,
|
|
||||||
# num_workers=c.num_loader_workers)
|
|
||||||
|
|
||||||
# frames = dataset.items
|
|
||||||
# for i, data in enumerate(dataloader):
|
|
||||||
# if i == self.max_loader_iter:
|
|
||||||
# break
|
|
||||||
# text_input = data[0]
|
|
||||||
# text_lengths = data[1]
|
|
||||||
# linear_input = data[2]
|
|
||||||
# mel_input = data[3]
|
|
||||||
# mel_lengths = data[4]
|
|
||||||
# stop_target = data[5]
|
|
||||||
# item_idx = data[6]
|
|
||||||
|
|
||||||
# neg_values = text_input[text_input < 0]
|
|
||||||
# check_count = len(neg_values)
|
|
||||||
# assert check_count == 0, \
|
|
||||||
# " !! Negative values in text_input: {}".format(check_count)
|
|
||||||
# assert mel_input.shape[0] == c.batch_size
|
|
||||||
# assert mel_input.shape[2] == c.audio['num_mels']
|
|
||||||
# dataloader.dataset.sort_items()
|
|
||||||
# assert frames[0] != dataloader.dataset.items[0]
|
|
||||||
|
|
||||||
# def test_padding_and_spec(self):
|
|
||||||
# if ok_ljspeech:
|
|
||||||
# dataset = TTSDatasetMemory.MyDataset(
|
|
||||||
# c.data_path_cache,
|
|
||||||
# 'tts_meta_data.csv',
|
|
||||||
# 1,
|
|
||||||
# c.text_cleaner,
|
|
||||||
# preprocessor=ljspeech,
|
|
||||||
# ap=self.ap,
|
|
||||||
# min_seq_len=c.min_seq_len)
|
|
||||||
|
|
||||||
# # Test for batch size 1
|
|
||||||
# dataloader = DataLoader(
|
|
||||||
# dataset,
|
|
||||||
# batch_size=1,
|
|
||||||
# shuffle=False,
|
|
||||||
# collate_fn=dataset.collate_fn,
|
|
||||||
# drop_last=True,
|
|
||||||
# num_workers=c.num_loader_workers)
|
|
||||||
|
|
||||||
# for i, data in enumerate(dataloader):
|
|
||||||
# if i == self.max_loader_iter:
|
|
||||||
# break
|
|
||||||
# text_input = data[0]
|
|
||||||
# text_lengths = data[1]
|
|
||||||
# linear_input = data[2]
|
|
||||||
# mel_input = data[3]
|
|
||||||
# mel_lengths = data[4]
|
|
||||||
# stop_target = data[5]
|
|
||||||
# item_idx = data[6]
|
|
||||||
|
|
||||||
# # check mel_spec consistency
|
|
||||||
# if item_idx[0].split('.')[-1] == 'npy':
|
|
||||||
# wav = np.load(item_idx[0])
|
|
||||||
# else:
|
|
||||||
# wav = self.ap.load_wav(item_idx[0])
|
|
||||||
# mel = self.ap.melspectrogram(wav)
|
|
||||||
# mel_dl = mel_input[0].cpu().numpy()
|
|
||||||
# assert (
|
|
||||||
# abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
|
|
||||||
|
|
||||||
# # check mel-spec correctness
|
|
||||||
# mel_spec = mel_input[0].cpu().numpy()
|
|
||||||
# wav = self.ap.inv_mel_spectrogram(mel_spec.T)
|
|
||||||
# self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav')
|
|
||||||
# shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_memo.wav')
|
|
||||||
|
|
||||||
# # check the last time step to be zero padded
|
|
||||||
# assert mel_input[0, -1].sum() == 0
|
|
||||||
# assert mel_input[0, -2].sum() != 0
|
|
||||||
# assert stop_target[0, -1] == 1
|
|
||||||
# assert stop_target[0, -2] == 0
|
|
||||||
# assert stop_target.sum() == 1
|
|
||||||
# assert len(mel_lengths.shape) == 1
|
|
||||||
# assert mel_lengths[0] == mel_input[0].shape[0]
|
|
||||||
|
|
||||||
# # Test for batch size 2
|
|
||||||
# dataloader = DataLoader(
|
|
||||||
# dataset,
|
|
||||||
# batch_size=2,
|
|
||||||
# shuffle=False,
|
|
||||||
# collate_fn=dataset.collate_fn,
|
|
||||||
# drop_last=False,
|
|
||||||
# num_workers=c.num_loader_workers)
|
|
||||||
|
|
||||||
# for i, data in enumerate(dataloader):
|
|
||||||
# if i == self.max_loader_iter:
|
|
||||||
# break
|
|
||||||
# text_input = data[0]
|
|
||||||
# text_lengths = data[1]
|
|
||||||
# linear_input = data[2]
|
|
||||||
# mel_input = data[3]
|
|
||||||
# mel_lengths = data[4]
|
|
||||||
# stop_target = data[5]
|
|
||||||
# item_idx = data[6]
|
|
||||||
|
|
||||||
# if mel_lengths[0] > mel_lengths[1]:
|
|
||||||
# idx = 0
|
|
||||||
# else:
|
|
||||||
# idx = 1
|
|
||||||
|
|
||||||
# # check the first item in the batch
|
|
||||||
# assert mel_input[idx, -1].sum() == 0
|
|
||||||
# assert mel_input[idx, -2].sum() != 0, mel_input
|
|
||||||
# assert stop_target[idx, -1] == 1
|
|
||||||
# assert stop_target[idx, -2] == 0
|
|
||||||
# assert stop_target[idx].sum() == 1
|
|
||||||
# assert len(mel_lengths.shape) == 1
|
|
||||||
# assert mel_lengths[idx] == mel_input[idx].shape[0]
|
|
||||||
|
|
||||||
# # check the second itme in the batch
|
|
||||||
# assert mel_input[1 - idx, -1].sum() == 0
|
|
||||||
# assert stop_target[1 - idx, -1] == 1
|
|
||||||
# assert len(mel_lengths.shape) == 1
|
|
||||||
|
|
||||||
# # check batch conditions
|
|
||||||
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
|
|
@ -35,8 +35,8 @@ class TacotronTrainTest(unittest.TestCase):
|
||||||
|
|
||||||
criterion = L1LossMasked().to(device)
|
criterion = L1LossMasked().to(device)
|
||||||
criterion_st = nn.BCELoss().to(device)
|
criterion_st = nn.BCELoss().to(device)
|
||||||
model = Tacotron(c.embedding_size, c.audio['num_freq'], c.audio['num_mels'],
|
model = Tacotron(32, c.embedding_size, c.audio['num_freq'], c.audio['num_mels'],
|
||||||
c.r).to(device)
|
c.r, c.memory_size).to(device)
|
||||||
model.train()
|
model.train()
|
||||||
model_ref = copy.deepcopy(model)
|
model_ref = copy.deepcopy(model)
|
||||||
count = 0
|
count = 0
|
||||||
|
|
|
@ -32,6 +32,7 @@
|
||||||
"mk": 1.0,
|
"mk": 1.0,
|
||||||
"priority_freq": false,
|
"priority_freq": false,
|
||||||
"num_loader_workers": 4,
|
"num_loader_workers": 4,
|
||||||
|
"memory_size": 5,
|
||||||
|
|
||||||
"save_step": 200,
|
"save_step": 200,
|
||||||
"data_path": "/home/erogol/Data/LJSpeech-1.1/",
|
"data_path": "/home/erogol/Data/LJSpeech-1.1/",
|
||||||
|
|
Loading…
Reference in New Issue