mirror of https://github.com/coqui-ai/TTS.git
update gan_datasets tests
parent 0ee0458309
commit ba80e82520
@@ -19,8 +19,8 @@ test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
 ok_ljspeech = os.path.exists(test_data_path)
 
 
-def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
-    ''' run dataloader with given parameters and check conditions '''
+def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments, use_noise_augment, use_cache, num_workers):
+    '''Run dataloader with given parameters and check conditions '''
     ap = AudioProcessor(**C.audio)
     _, train_items = load_wav_data(test_data_path, 10)
     dataset = GANDataset(ap,
@@ -29,6 +29,7 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
                          hop_len=hop_len,
                          pad_short=2000,
                          conv_pad=conv_pad,
+                         return_pairs=return_pairs,
                          return_segments=return_segments,
                          use_noise_augment=use_noise_augment,
                          use_cache=use_cache)
@@ -42,31 +43,41 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
     max_iter = 10
     count_iter = 0
 
+    def check_item(feat, wav):
+        """Pass a single pair of features and waveform"""
+        expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)
+
+        # check shapes
+        assert np.all(feat.shape == expected_feat_shape), f" [!] {feat.shape} vs {expected_feat_shape}"
+        assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
+
+        # check feature vs audio match
+        if not use_noise_augment:
+            for idx in range(batch_size):
+                audio = wav[idx].squeeze()
+                feat = feat[idx]
+                mel = ap.melspectrogram(audio)
+                # the first 2 and the last 2 frames are skipped due to the padding
+                # differences in stft
+                max_diff = abs((feat - mel[:, :feat.shape[-1]])[:, 2:-2]).max()
+                assert max_diff <= 0, f' [!] {max_diff}'
+
+
     # return random segments or return the whole audio
     if return_segments:
-        for item1, _ in loader:
-            feat1, wav1 = item1
-            # feat2, wav2 = item2
-            expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)
-
-            # check shapes
-            assert np.all(feat1.shape == expected_feat_shape), f" [!] {feat1.shape} vs {expected_feat_shape}"
-            assert (feat1.shape[2] - conv_pad * 2) * hop_len == wav1.shape[2]
-
-            # check feature vs audio match
-            if not use_noise_augment:
-                for idx in range(batch_size):
-                    audio = wav1[idx].squeeze()
-                    feat = feat1[idx]
-                    mel = ap.melspectrogram(audio)
-                    # the first 2 and the last 2 frames are skipped due to the padding
-                    # differences in stft
-                    max_diff = abs((feat - mel[:, :feat1.shape[-1]])[:, 2:-2]).max()
-                    assert max_diff <= 0, f' [!] {max_diff}'
-
-            count_iter += 1
-            # if count_iter == max_iter:
-            #     break
+        if return_pairs:
+            for item1, item2 in loader:
+                feat1, wav1 = item1
+                feat2, wav2 = item2
+                check_item(feat1, wav1)
+                check_item(feat2, wav2)
+                count_iter += 1
+        else:
+            for item1 in loader:
+                feat1, wav1 = item1
+                check_item(feat1, wav1)
+                count_iter += 1
     else:
         for item in loader:
             feat, wav = item
@@ -81,15 +92,16 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
 def test_parametrized_gan_dataset():
     ''' test dataloader with different parameters '''
     params = [
-        [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, 0],
-        [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, 4],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, True, 0],
-        [1, C.audio['hop_length'], C.audio['hop_length'], 0, True, True, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, True, True, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, False, True, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, False, False, False, 0],
+        [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, True, 0],
+        [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, True, 4],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, True, True, 0],
+        [1, C.audio['hop_length'], C.audio['hop_length'], 0, True, True, True, True, 0],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, True, True, True, True, 0],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, True, 0],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, True, 0],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, False, True, True, False, 0],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, False, False, 0],
+        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, False, False, 0]
     ]
     for param in params:
         print(param)
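
Note: the new return_pairs flag makes the dataset yield two independent (feature, waveform) pairs per step instead of one, and the added check_item helper factors the shape and mel-spectrogram assertions out so both pairs go through the same checks. Below is a minimal sketch of how the updated dataset might be exercised; the module import paths, the DataLoader settings, and the positional train_items / seq_len arguments to GANDataset do not appear in these hunks and are assumptions, as are C and test_data_path, which the test module defines elsewhere.

# Sketch only -- import paths, DataLoader settings, and the positional
# train_items / seq_len arguments are assumptions, not shown in this diff.
from torch.utils.data import DataLoader

from TTS.utils.audio import AudioProcessor                 # assumed import path
from TTS.vocoder.datasets.gan_dataset import GANDataset    # assumed import path
from TTS.vocoder.datasets.preprocess import load_wav_data  # assumed import path

ap = AudioProcessor(**C.audio)                  # C and test_data_path come from the test module
_, train_items = load_wav_data(test_data_path, 10)

dataset = GANDataset(ap,
                     train_items,
                     seq_len=C.audio['hop_length'] * 10,
                     hop_len=C.audio['hop_length'],
                     pad_short=2000,
                     conv_pad=0,
                     return_pairs=True,         # the flag added by this commit
                     return_segments=True,
                     use_noise_augment=False,
                     use_cache=False)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

for item1, item2 in loader:   # with return_pairs=True each step yields two (feat, wav) pairs
    feat1, wav1 = item1
    feat2, wav2 = item2
    break                     # with return_pairs=False the loader yields a single item per step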