mirror of https://github.com/coqui-ai/TTS.git
Update data loader tests
This commit is contained in:
parent
75c507c36a
commit
196ae74273
|
@ -38,7 +38,7 @@ class TTSDataset(Dataset):
|
||||||
outputs_per_step: int,
|
outputs_per_step: int,
|
||||||
compute_linear_spec: bool,
|
compute_linear_spec: bool,
|
||||||
ap: AudioProcessor,
|
ap: AudioProcessor,
|
||||||
meta_data: List[Dict],
|
samples: List[Dict],
|
||||||
tokenizer: "TTSTokenizer" = None,
|
tokenizer: "TTSTokenizer" = None,
|
||||||
compute_f0: bool = False,
|
compute_f0: bool = False,
|
||||||
f0_cache_path: str = None,
|
f0_cache_path: str = None,
|
||||||
|
@ -67,7 +67,7 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
|
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
|
||||||
|
|
||||||
meta_data (list): List of dataset samples.
|
samples (list): List of dataset samples.
|
||||||
|
|
||||||
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
|
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
|
||||||
use the given. Defaults to None.
|
use the given. Defaults to None.
|
||||||
|
@ -111,7 +111,7 @@ class TTSDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_group_size = batch_group_size
|
self.batch_group_size = batch_group_size
|
||||||
self._samples = meta_data
|
self._samples = samples
|
||||||
self.outputs_per_step = outputs_per_step
|
self.outputs_per_step = outputs_per_step
|
||||||
self.sample_rate = ap.sample_rate
|
self.sample_rate = ap.sample_rate
|
||||||
self.compute_linear_spec = compute_linear_spec
|
self.compute_linear_spec = compute_linear_spec
|
||||||
|
@ -200,7 +200,7 @@ class TTSDataset(Dataset):
|
||||||
token_ids = self.get_phonemes(idx, text)["token_ids"]
|
token_ids = self.get_phonemes(idx, text)["token_ids"]
|
||||||
else:
|
else:
|
||||||
token_ids = self.tokenizer.text_to_ids(text)
|
token_ids = self.tokenizer.text_to_ids(text)
|
||||||
return token_ids
|
return np.array(token_ids, dtype=np.int32)
|
||||||
|
|
||||||
def load_data(self, idx):
|
def load_data(self, idx):
|
||||||
item = self.samples[idx]
|
item = self.samples[idx]
|
||||||
|
@ -258,7 +258,7 @@ class TTSDataset(Dataset):
|
||||||
return audio_lengths, text_lengths
|
return audio_lengths, text_lengths
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int):
|
def filter_by_length(lengths: List[int], min_len: int, max_len: int):
|
||||||
idxs = np.argsort(lengths) # ascending order
|
idxs = np.argsort(lengths) # ascending order
|
||||||
ignore_idx = []
|
ignore_idx = []
|
||||||
keep_idx = []
|
keep_idx = []
|
||||||
|
@ -270,6 +270,11 @@ class TTSDataset(Dataset):
|
||||||
keep_idx.append(idx)
|
keep_idx.append(idx)
|
||||||
return ignore_idx, keep_idx
|
return ignore_idx, keep_idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def sort_by_length(lengths: List[int]):
|
||||||
|
idxs = np.argsort(lengths) # ascending order
|
||||||
|
return idxs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_buckets(samples, batch_group_size: int):
|
def create_buckets(samples, batch_group_size: int):
|
||||||
for i in range(len(samples) // batch_group_size):
|
for i in range(len(samples) // batch_group_size):
|
||||||
|
@ -280,24 +285,33 @@ class TTSDataset(Dataset):
|
||||||
samples[offset:end_offset] = temp_items
|
samples[offset:end_offset] = temp_items
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
def select_samples_by_idx(self, idxs):
|
||||||
|
samples = []
|
||||||
|
audio_lengths = []
|
||||||
|
text_lengths = []
|
||||||
|
for idx in idxs:
|
||||||
|
samples.append(self.samples[idx])
|
||||||
|
audio_lengths.append(self.audio_lengths[idx])
|
||||||
|
text_lengths.append(self.text_lengths[idx])
|
||||||
|
return samples, audio_lengths, text_lengths
|
||||||
|
|
||||||
def preprocess_samples(self):
|
def preprocess_samples(self):
|
||||||
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
||||||
range.
|
range.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# sort items based on the sequence length in ascending order
|
# sort items based on the sequence length in ascending order
|
||||||
text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(
|
text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
|
||||||
self.text_lengths, self.min_text_len, self.max_text_len
|
audio_ignore_idx, audio_keep_idx = self.filter_by_length(
|
||||||
)
|
|
||||||
audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(
|
|
||||||
self.audio_lengths, self.min_audio_len, self.max_audio_len
|
self.audio_lengths, self.min_audio_len, self.max_audio_len
|
||||||
)
|
)
|
||||||
keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
|
keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
|
||||||
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
|
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
|
||||||
|
|
||||||
samples = []
|
samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
|
||||||
for idx in keep_idx:
|
|
||||||
samples.append(self.samples[idx])
|
sorted_idxs = self.sort_by_length(audio_lengths)
|
||||||
|
samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
|
||||||
|
|
||||||
if len(samples) == 0:
|
if len(samples) == 0:
|
||||||
raise RuntimeError(" [!] No samples left")
|
raise RuntimeError(" [!] No samples left")
|
||||||
|
@ -309,6 +323,8 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
# update items to the new sorted items
|
# update items to the new sorted items
|
||||||
self.samples = samples
|
self.samples = samples
|
||||||
|
self.audio_lengths = audio_lengths
|
||||||
|
self.text_lengths = text_lengtsh
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(" | > Preprocessing samples")
|
print(" | > Preprocessing samples")
|
||||||
|
@ -391,7 +407,7 @@ class TTSDataset(Dataset):
|
||||||
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
||||||
|
|
||||||
# PAD sequences with longest instance in the batch
|
# PAD sequences with longest instance in the batch
|
||||||
text = prepare_data(batch["token_ids"]).astype(np.int32)
|
token_ids = prepare_data(batch["token_ids"]).astype(np.int32)
|
||||||
|
|
||||||
# PAD features with longest instance
|
# PAD features with longest instance
|
||||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||||
|
@ -401,7 +417,7 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
# convert things to pytorch
|
# convert things to pytorch
|
||||||
token_ids_lengths = torch.LongTensor(token_ids_lengths)
|
token_ids_lengths = torch.LongTensor(token_ids_lengths)
|
||||||
text = torch.LongTensor(text)
|
token_ids = torch.LongTensor(token_ids)
|
||||||
mel = torch.FloatTensor(mel).contiguous()
|
mel = torch.FloatTensor(mel).contiguous()
|
||||||
mel_lengths = torch.LongTensor(mel_lengths)
|
mel_lengths = torch.LongTensor(mel_lengths)
|
||||||
stop_targets = torch.FloatTensor(stop_targets)
|
stop_targets = torch.FloatTensor(stop_targets)
|
||||||
|
@ -453,7 +469,7 @@ class TTSDataset(Dataset):
|
||||||
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
||||||
for idx, attn in enumerate(attns):
|
for idx, attn in enumerate(attns):
|
||||||
pad2 = mel.shape[1] - attn.shape[1]
|
pad2 = mel.shape[1] - attn.shape[1]
|
||||||
pad1 = text.shape[1] - attn.shape[0]
|
pad1 = token_ids.shape[1] - attn.shape[0]
|
||||||
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
|
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
|
||||||
attn = np.pad(attn, [[0, pad1], [0, pad2]])
|
attn = np.pad(attn, [[0, pad1], [0, pad2]])
|
||||||
attns[idx] = attn
|
attns[idx] = attn
|
||||||
|
@ -461,7 +477,7 @@ class TTSDataset(Dataset):
|
||||||
attns = torch.FloatTensor(attns).unsqueeze(1)
|
attns = torch.FloatTensor(attns).unsqueeze(1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"token_id": text,
|
"token_id": token_ids,
|
||||||
"token_id_lengths": token_ids_lengths,
|
"token_id_lengths": token_ids_lengths,
|
||||||
"speaker_names": batch["speaker_name"],
|
"speaker_names": batch["speaker_name"],
|
||||||
"linear": linear,
|
"linear": linear,
|
||||||
|
@ -786,7 +802,7 @@ if __name__ == "__main__":
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
outputs_per_step=1,
|
outputs_per_step=1,
|
||||||
compute_linear_spec=False,
|
compute_linear_spec=False,
|
||||||
meta_data=samples,
|
samples=samples,
|
||||||
ap=ap,
|
ap=ap,
|
||||||
return_wav=False,
|
return_wav=False,
|
||||||
batch_group_size=0,
|
batch_group_size=0,
|
||||||
|
|
|
@ -147,6 +147,7 @@ class TTSTokenizer:
|
||||||
if isinstance(config.text_cleaner, (str, list)):
|
if isinstance(config.text_cleaner, (str, list)):
|
||||||
text_cleaner = getattr(cleaners, config.text_cleaner)
|
text_cleaner = getattr(cleaners, config.text_cleaner)
|
||||||
|
|
||||||
|
phonemizer = None
|
||||||
if config.use_phonemes:
|
if config.use_phonemes:
|
||||||
# init phoneme set
|
# init phoneme set
|
||||||
characters = IPAPhonemes().init_from_config(config)
|
characters = IPAPhonemes().init_from_config(config)
|
||||||
|
|
|
@ -7,9 +7,9 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from tests import get_tests_output_path
|
from tests import get_tests_output_path
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig, BaseDatasetConfig
|
||||||
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
# pylint: disable=unused-variable
|
# pylint: disable=unused-variable
|
||||||
|
@ -50,18 +50,19 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2)
|
meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2)
|
||||||
items = meta_data_train + meta_data_eval
|
items = meta_data_train + meta_data_eval
|
||||||
|
|
||||||
|
tokenizer = TTSTokenizer.init_from_config(c)
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
r,
|
outputs_per_step=r,
|
||||||
c.text_cleaner,
|
|
||||||
compute_linear_spec=True,
|
compute_linear_spec=True,
|
||||||
return_wav=True,
|
return_wav=True,
|
||||||
|
tokenizer=tokenizer,
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
meta_data=items,
|
samples=items,
|
||||||
characters=c.characters,
|
|
||||||
batch_group_size=bgs,
|
batch_group_size=bgs,
|
||||||
min_seq_len=c.min_seq_len,
|
min_text_len=c.min_text_len,
|
||||||
max_seq_len=float("inf"),
|
max_text_len=c.max_text_len,
|
||||||
use_phonemes=False,
|
min_audio_len=c.min_audio_len,
|
||||||
|
max_audio_len=c.max_audio_len,
|
||||||
)
|
)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -80,27 +81,26 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data["text"]
|
text_input = data["token_id"]
|
||||||
text_lengths = data["text_lengths"]
|
_ = data["token_id_lengths"]
|
||||||
speaker_name = data["speaker_names"]
|
speaker_name = data["speaker_names"]
|
||||||
linear_input = data["linear"]
|
linear_input = data["linear"]
|
||||||
mel_input = data["mel"]
|
mel_input = data["mel"]
|
||||||
mel_lengths = data["mel_lengths"]
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data["stop_targets"]
|
_ = data["stop_targets"]
|
||||||
item_idx = data["item_idxs"]
|
_ = data["item_idxs"]
|
||||||
wavs = data["waveform"]
|
wavs = data["waveform"]
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
neg_values = text_input[text_input < 0]
|
||||||
check_count = len(neg_values)
|
check_count = len(neg_values)
|
||||||
assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
|
|
||||||
assert isinstance(speaker_name[0], str)
|
# check basic conditions
|
||||||
assert linear_input.shape[0] == c.batch_size
|
self.assertEqual(check_count, 0)
|
||||||
assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
|
self.assertEqual(linear_input.shape[0], mel_input.shape[0], c.batch_size)
|
||||||
assert mel_input.shape[0] == c.batch_size
|
self.assertEqual(linear_input.shape[2], self.ap.fft_size // 2 + 1)
|
||||||
assert mel_input.shape[2] == c.audio["num_mels"]
|
self.assertEqual(mel_input.shape[2], c.audio["num_mels"])
|
||||||
assert (
|
self.assertEqual(wavs.shape[1], mel_input.shape[1] * c.audio.hop_length)
|
||||||
wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
|
self.assertIsInstance(speaker_name[0], str)
|
||||||
), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
|
|
||||||
|
|
||||||
# make sure that the computed mels and the waveform match and correctly computed
|
# make sure that the computed mels and the waveform match and correctly computed
|
||||||
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
|
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
|
||||||
|
@ -109,55 +109,58 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
# guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
|
# guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
|
||||||
mel_new = mel_new[:, :mel_lengths[0]]
|
mel_new = mel_new[:, :mel_lengths[0]]
|
||||||
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
|
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
|
||||||
mel_diff = (mel_new - mel_dataloader)[:, 0:ignore_seg]
|
mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
|
||||||
assert abs(mel_diff.sum()) < 1e-5
|
self.assertLess(abs(mel_diff.sum()), 1e-5)
|
||||||
|
|
||||||
# check normalization ranges
|
# check normalization ranges
|
||||||
if self.ap.symmetric_norm:
|
if self.ap.symmetric_norm:
|
||||||
assert mel_input.max() <= self.ap.max_norm
|
self.assertLessEqual(mel_input.max(), self.ap.max_norm)
|
||||||
assert mel_input.min() >= -self.ap.max_norm # pylint: disable=invalid-unary-operand-type
|
self.assertGreaterEqual(
|
||||||
assert mel_input.min() < 0
|
mel_input.min(), -self.ap.max_norm
|
||||||
|
) # pylint: disable=invalid-unary-operand-type
|
||||||
|
self.assertLess(mel_input.min(), 0)
|
||||||
else:
|
else:
|
||||||
assert mel_input.max() <= self.ap.max_norm
|
self.assertLessEqual(mel_input.max(), self.ap.max_norm)
|
||||||
assert mel_input.min() >= 0
|
self.assertGreaterEqual(mel_input.min(), 0)
|
||||||
|
|
||||||
def test_batch_group_shuffle(self):
|
def test_batch_group_shuffle(self):
|
||||||
if ok_ljspeech:
|
if ok_ljspeech:
|
||||||
dataloader, dataset = self._create_dataloader(2, c.r, 16)
|
dataloader, dataset = self._create_dataloader(2, c.r, 16)
|
||||||
last_length = 0
|
last_length = 0
|
||||||
frames = dataset.items
|
frames = dataset.samples
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data["text"]
|
|
||||||
text_lengths = data["text_lengths"]
|
|
||||||
speaker_name = data["speaker_names"]
|
|
||||||
linear_input = data["linear"]
|
|
||||||
mel_input = data["mel"]
|
|
||||||
mel_lengths = data["mel_lengths"]
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data["stop_targets"]
|
|
||||||
item_idx = data["item_idxs"]
|
|
||||||
|
|
||||||
avg_length = mel_lengths.numpy().mean()
|
avg_length = mel_lengths.numpy().mean()
|
||||||
assert avg_length >= last_length
|
dataloader.dataset.preprocess_samples()
|
||||||
dataloader.dataset.sort_and_filter_items()
|
|
||||||
is_items_reordered = False
|
is_items_reordered = False
|
||||||
for idx, item in enumerate(dataloader.dataset.items):
|
for idx, item in enumerate(dataloader.dataset.samples):
|
||||||
if item != frames[idx]:
|
if item != frames[idx]:
|
||||||
is_items_reordered = True
|
is_items_reordered = True
|
||||||
break
|
break
|
||||||
assert is_items_reordered
|
self.assertGreaterEqual(avg_length, last_length)
|
||||||
|
self.assertTrue(is_items_reordered)
|
||||||
|
|
||||||
|
def test_padding_and_spectrograms(self):
|
||||||
|
def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
|
||||||
|
self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding
|
||||||
|
self.assertNotEqual(linear_input[idx, -2].sum(), 0)
|
||||||
|
self.assertNotEqual(mel_input[idx, -1].sum(), 0)
|
||||||
|
self.assertNotEqual(mel_input[idx, -2].sum(), 0)
|
||||||
|
self.assertEqual(stop_target[idx, -1], 1)
|
||||||
|
self.assertEqual(stop_target[idx, -2], 0)
|
||||||
|
self.assertEqual(stop_target[idx].sum(), 1)
|
||||||
|
self.assertEqual(len(mel_lengths.shape), 1)
|
||||||
|
self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0])
|
||||||
|
self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0])
|
||||||
|
|
||||||
def test_padding_and_spec(self):
|
|
||||||
if ok_ljspeech:
|
if ok_ljspeech:
|
||||||
dataloader, dataset = self._create_dataloader(1, 1, 0)
|
dataloader, _ = self._create_dataloader(1, 1, 0)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data["text"]
|
|
||||||
text_lengths = data["text_lengths"]
|
|
||||||
speaker_name = data["speaker_names"]
|
|
||||||
linear_input = data["linear"]
|
linear_input = data["linear"]
|
||||||
mel_input = data["mel"]
|
mel_input = data["mel"]
|
||||||
mel_lengths = data["mel_lengths"]
|
mel_lengths = data["mel_lengths"]
|
||||||
|
@ -172,7 +175,7 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
# NOTE: Below needs to check == 0 but due to an unknown reason
|
# NOTE: Below needs to check == 0 but due to an unknown reason
|
||||||
# there is a slight difference between two matrices.
|
# there is a slight difference between two matrices.
|
||||||
# TODO: Check this assert cond more in detail.
|
# TODO: Check this assert cond more in detail.
|
||||||
assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max()
|
self.assertLess(abs(mel.T - mel_dl).max(), 1e-5)
|
||||||
|
|
||||||
# check mel-spec correctness
|
# check mel-spec correctness
|
||||||
mel_spec = mel_input[0].cpu().numpy()
|
mel_spec = mel_input[0].cpu().numpy()
|
||||||
|
@ -186,56 +189,36 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
|
self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
|
||||||
shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")
|
shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")
|
||||||
|
|
||||||
# check the last time step to be zero padded
|
# check the outputs
|
||||||
assert linear_input[0, -1].sum() != 0
|
check_conditions(0, linear_input, mel_input, stop_target, mel_lengths)
|
||||||
assert linear_input[0, -2].sum() != 0
|
|
||||||
assert mel_input[0, -1].sum() != 0
|
|
||||||
assert mel_input[0, -2].sum() != 0
|
|
||||||
assert stop_target[0, -1] == 1
|
|
||||||
assert stop_target[0, -2] == 0
|
|
||||||
assert stop_target.sum() == 1
|
|
||||||
assert len(mel_lengths.shape) == 1
|
|
||||||
assert mel_lengths[0] == linear_input[0].shape[0]
|
|
||||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
|
||||||
|
|
||||||
# Test for batch size 2
|
# Test for batch size 2
|
||||||
dataloader, dataset = self._create_dataloader(2, 1, 0)
|
dataloader, _ = self._create_dataloader(2, 1, 0)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data["text"]
|
|
||||||
text_lengths = data["text_lengths"]
|
|
||||||
speaker_name = data["speaker_names"]
|
|
||||||
linear_input = data["linear"]
|
linear_input = data["linear"]
|
||||||
mel_input = data["mel"]
|
mel_input = data["mel"]
|
||||||
mel_lengths = data["mel_lengths"]
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data["stop_targets"]
|
stop_target = data["stop_targets"]
|
||||||
item_idx = data["item_idxs"]
|
item_idx = data["item_idxs"]
|
||||||
|
|
||||||
|
# set id to the longest sequence in the batch
|
||||||
if mel_lengths[0] > mel_lengths[1]:
|
if mel_lengths[0] > mel_lengths[1]:
|
||||||
idx = 0
|
idx = 0
|
||||||
else:
|
else:
|
||||||
idx = 1
|
idx = 1
|
||||||
|
|
||||||
# check the first item in the batch
|
# check the longer item in the batch
|
||||||
assert linear_input[idx, -1].sum() != 0
|
check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths)
|
||||||
assert linear_input[idx, -2].sum() != 0, linear_input
|
|
||||||
assert mel_input[idx, -1].sum() != 0
|
|
||||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
|
||||||
assert stop_target[idx, -1] == 1
|
|
||||||
assert stop_target[idx, -2] == 0
|
|
||||||
assert stop_target[idx].sum() == 1
|
|
||||||
assert len(mel_lengths.shape) == 1
|
|
||||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
|
||||||
assert mel_lengths[idx] == linear_input[idx].shape[0]
|
|
||||||
|
|
||||||
# check the second itme in the batch
|
# check the other item in the batch
|
||||||
assert linear_input[1 - idx, -1].sum() == 0
|
self.assertEqual(linear_input[1 - idx, -1].sum(), 0)
|
||||||
assert mel_input[1 - idx, -1].sum() == 0
|
self.assertEqual(mel_input[1 - idx, -1].sum(), 0)
|
||||||
assert stop_target[1, mel_lengths[1] - 1] == 1
|
self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1)
|
||||||
assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1]
|
self.assertEqual(stop_target[1, mel_lengths[1] :].sum(), stop_target.shape[1] - mel_lengths[1])
|
||||||
assert len(mel_lengths.shape) == 1
|
self.assertEqual(len(mel_lengths.shape), 1)
|
||||||
|
|
||||||
# check batch zero-frame conditions (zero-frame disabled)
|
# check batch zero-frame conditions (zero-frame disabled)
|
||||||
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
Loading…
Reference in New Issue