mirror of https://github.com/coqui-ai/TTS.git
Test updates for loaders
This commit is contained in:
parent
5d7338ddf6
commit
8864252941
|
@ -2,7 +2,8 @@ import unittest
|
|||
import torch as T
|
||||
|
||||
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
||||
from TTS.layers.losses import L1LossMasked, _sequence_mask
|
||||
from TTS.layers.losses import L1LossMasked
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
|
||||
class PrenetTests(unittest.TestCase):
|
||||
|
@ -79,7 +80,7 @@ class L1LossMaskedTests(unittest.TestCase):
|
|||
dummy_input = T.ones(4, 8, 128).float()
|
||||
dummy_target = T.zeros(4, 8, 128).float()
|
||||
dummy_length = (T.arange(5, 9)).long()
|
||||
mask = ((_sequence_mask(dummy_length).float() - 1.0)
|
||||
mask = ((sequence_mask(dummy_length).float() - 1.0)
|
||||
* 100.0).unsqueeze(2)
|
||||
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
|
||||
|
|
|
@ -4,7 +4,8 @@ import numpy as np
|
|||
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.datasets.LJSpeech import LJSpeechDataset
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.datasets import LJSpeech, Kusal
|
||||
|
||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||
|
@ -15,21 +16,25 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
|
||||
self.max_loader_iter = 4
|
||||
self.ap = AudioProcessor(sample_rate=c.sample_rate,
|
||||
num_mels=c.num_mels,
|
||||
min_level_db=c.min_level_db,
|
||||
frame_shift_ms=c.frame_shift_ms,
|
||||
frame_length_ms=c.frame_length_ms,
|
||||
ref_level_db=c.ref_level_db,
|
||||
num_freq=c.num_freq,
|
||||
power=c.power,
|
||||
preemphasis=c.preemphasis,
|
||||
min_mel_freq=c.min_mel_freq,
|
||||
max_mel_freq=c.max_mel_freq)
|
||||
|
||||
def test_loader(self):
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
||||
dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
c.r,
|
||||
c.sample_rate,
|
||||
c.text_cleaner,
|
||||
c.num_mels,
|
||||
c.min_level_db,
|
||||
c.frame_shift_ms,
|
||||
c.frame_length_ms,
|
||||
c.preemphasis,
|
||||
c.ref_level_db,
|
||||
c.num_freq,
|
||||
c.power
|
||||
ap = self.ap,
|
||||
min_seq_len=c.min_seq_len
|
||||
)
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=2,
|
||||
|
@ -57,19 +62,12 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
def test_padding(self):
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
||||
dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
1,
|
||||
c.sample_rate,
|
||||
c.text_cleaner,
|
||||
c.num_mels,
|
||||
c.min_level_db,
|
||||
c.frame_shift_ms,
|
||||
c.frame_length_ms,
|
||||
c.preemphasis,
|
||||
c.ref_level_db,
|
||||
c.num_freq,
|
||||
c.power
|
||||
ap = self.ap,
|
||||
min_seq_len=c.min_seq_len
|
||||
)
|
||||
|
||||
# Test for batch size 1
|
||||
|
@ -141,6 +139,135 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
|
||||
|
||||
class TestKusalDataset(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestKusalDataset, self).__init__(*args, **kwargs)
|
||||
self.max_loader_iter = 4
|
||||
self.ap = AudioProcessor(sample_rate=c.sample_rate,
|
||||
num_mels=c.num_mels,
|
||||
min_level_db=c.min_level_db,
|
||||
frame_shift_ms=c.frame_shift_ms,
|
||||
frame_length_ms=c.frame_length_ms,
|
||||
ref_level_db=c.ref_level_db,
|
||||
num_freq=c.num_freq,
|
||||
power=c.power,
|
||||
preemphasis=c.preemphasis,
|
||||
min_mel_freq=c.min_mel_freq,
|
||||
max_mel_freq=c.max_mel_freq)
|
||||
|
||||
def test_loader(self):
|
||||
dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
|
||||
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||
c.r,
|
||||
c.text_cleaner,
|
||||
ap = self.ap,
|
||||
min_seq_len=c.min_seq_len
|
||||
)
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=2,
|
||||
shuffle=True, collate_fn=dataset.collate_fn,
|
||||
drop_last=True, num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
def test_padding(self):
|
||||
dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
|
||||
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||
1,
|
||||
c.text_cleaner,
|
||||
ap = self.ap,
|
||||
min_seq_len=c.min_seq_len
|
||||
)
|
||||
|
||||
# Test for batch size 1
|
||||
dataloader = DataLoader(dataset, batch_size=1,
|
||||
shuffle=False, collate_fn=dataset.collate_fn,
|
||||
drop_last=True, num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
# check the last time step to be zero padded
|
||||
assert mel_input[0, -1].sum() == 0
|
||||
# assert mel_input[0, -2].sum() != 0
|
||||
assert linear_input[0, -1].sum() == 0
|
||||
# assert linear_input[0, -2].sum() != 0
|
||||
assert stop_target[0, -1] == 1
|
||||
assert stop_target[0, -2] == 0
|
||||
assert stop_target.sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||
|
||||
# Test for batch size 2
|
||||
dataloader = DataLoader(dataset, batch_size=2,
|
||||
shuffle=False, collate_fn=dataset.collate_fn,
|
||||
drop_last=False, num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
|
||||
# check the first item in the batch
|
||||
assert mel_input[idx, -1].sum() == 0
|
||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
||||
assert linear_input[idx, -1].sum() == 0
|
||||
assert linear_input[idx, -2].sum() != 0
|
||||
assert stop_target[idx, -1] == 1
|
||||
assert stop_target[idx, -2] == 0
|
||||
assert stop_target[idx].sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||
|
||||
# check the second itme in the batch
|
||||
assert mel_input[1-idx, -1].sum() == 0
|
||||
assert linear_input[1-idx, -1].sum() == 0
|
||||
assert stop_target[1-idx, -1] == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
|
||||
# check batch conditions
|
||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
|
||||
|
||||
# class TestTWEBDataset(unittest.TestCase):
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
"ref_level_db": 20,
|
||||
"hidden_size": 128,
|
||||
"embedding_size": 256,
|
||||
"min_mel_freq": null,
|
||||
"max_mel_freq": null,
|
||||
"text_cleaner": "english_cleaners",
|
||||
|
||||
"epochs": 2000,
|
||||
|
@ -27,8 +29,9 @@
|
|||
"num_loader_workers": 4,
|
||||
|
||||
"save_step": 200,
|
||||
"data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||
"data_path_TWEB": "/data/shared/BibleSpeech",
|
||||
"data_path_LJSpeech": "C:/Users/erogol/Data/LJSpeech-1.1",
|
||||
"data_path_Kusal": "C:/Users/erogol/Data/Kusal",
|
||||
"output_path": "result",
|
||||
"min_seq_len": 0,
|
||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue