Implement `start_by_longest` option for TTSDataset

Eren Gölge 2022-01-21 15:29:06 +00:00
parent c4c471d61d
commit ef63c99524
5 changed files with 35 additions and 12 deletions

View File

@@ -172,6 +172,10 @@ class BaseTTSConfig(BaseTrainingConfig):
use_noise_augment (bool):
Augment the input audio with random noise.
start_by_longest (bool):
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
Defaults to False.
add_blank (bool):
Add blank characters between each two characters. It improves performance for some models at the expense
of slower run-time due to the longer input sequence.
@@ -231,6 +235,7 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False
start_by_longest: bool = False
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
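
For reference, a minimal usage sketch (not part of this commit): `start_by_longest` is just another `BaseTTSConfig` field, so it can be switched on from any model config that derives from it. The import path below is assumed to follow the usual Coqui TTS layout.

from TTS.tts.configs.shared_configs import BaseTTSConfig

# Assumption: all BaseTTSConfig fields have defaults, so the config can be
# built with only this flag overridden; subclasses such as VitsConfig inherit it.
config = BaseTTSConfig(start_by_longest=True)
print(config.start_by_longest)  # True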

View File

@@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
sort_by_audio_len (bool):
If true, the data loader sorts the data by audio length; otherwise it sorts by the input text length. Defaults to `True`.
min_seq_len (int):
Minimum sequence length to be considered for training. Defaults to `0`.
max_seq_len (int):
Maximum sequence length to be considered for training. Defaults to `500000`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@@ -123,6 +114,7 @@ class VitsConfig(BaseTTSConfig):
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
speaker_encoder_loss_alpha: float = 1.0
# data loader params
@@ -130,9 +122,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec: bool = True
# overrides
sort_by_audio_len: bool = True
min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@@ -56,6 +56,7 @@ class TTSDataset(Dataset):
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
start_by_longest: bool = False,
verbose: bool = False,
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
@@ -109,6 +110,8 @@ class TTSDataset(Dataset):
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
@@ -130,6 +133,7 @@ class TTSDataset(Dataset):
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
self.start_by_longest = start_by_longest
self.verbose = verbose
self.rescue_item_idx = 1
@@ -315,6 +319,12 @@ class TTSDataset(Dataset):
samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
sorted_idxs = self.sort_by_length(audio_lengths)
if self.start_by_longest:
longest_idxs = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idxs
samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
if len(samples) == 0:
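
The swap above only exchanges the first and last entries of the length-sorted index list, so the longest utterance leads the first batch while the shortest moves to the back and everything else keeps its order. A self-contained sketch of the same idea on plain Python lists (names are illustrative, not the dataset's internals):

# Illustrative only: mimic an ascending length sort followed by the
# start_by_longest swap.
audio_lengths = [120, 480, 90, 300]

# Argsort ascending: shortest ... longest.
sorted_idxs = sorted(range(len(audio_lengths)), key=lambda i: audio_lengths[i])
print(sorted_idxs)  # [2, 0, 3, 1]

# Swap first and last so the longest sample is loaded first.
longest_idx = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idx
print(sorted_idxs)  # [1, 0, 3, 2]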

View File

@@ -290,6 +290,7 @@ class BaseTTS(BaseModel):
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
language_id_mapping=language_id_mapping,
)

View File

@@ -63,6 +63,7 @@ class TestTTSDataset(unittest.TestCase):
max_text_len=c.max_text_len,
min_audio_len=c.min_audio_len,
max_audio_len=c.max_audio_len,
start_by_longest=start_by_longest
)
dataloader = DataLoader(
dataset,
@@ -142,6 +143,23 @@ class TestTTSDataset(unittest.TestCase):
self.assertGreaterEqual(avg_length, last_length)
self.assertTrue(is_items_reordered)
def test_start_by_longest(self):
"""Test start_by_longest option.
Ther first item of the fist batch must be longer than all the other items.
"""
if ok_ljspeech:
dataloader, _ = self._create_dataloader(2, c.r, 0, True)
dataloader.dataset.preprocess_samples()
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
mel_lengths = data["mel_lengths"]
if i == 0:
max_len = mel_lengths[0]
print(mel_lengths)
self.assertTrue(all(max_len >= mel_lengths))
def test_padding_and_spectrograms(self):
def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding
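
As a side note on the new test's assertion: `mel_lengths` holds the lengths of one batch, so `max_len >= mel_lengths` broadcasts the comparison across the whole batch, and `all(...)` passes only when no later item exceeds the first item of the first batch. A standalone illustration with hypothetical values, independent of the test fixtures:

import torch

# Hypothetical mel lengths of the first batch; item 0 should be the longest.
mel_lengths = torch.tensor([431, 212, 387, 158])
max_len = mel_lengths[0]

# Elementwise comparison -> tensor([True, True, True, True]).
print(all(max_len >= mel_lengths))  # True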