diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index ad3bbe70..5c271f07 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -172,6 +172,10 @@ class BaseTTSConfig(BaseTrainingConfig): use_noise_augment (bool): Augment the input audio with random noise. + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + add_blank (bool): Add blank characters between each other two characters. It improves performance for some models at expense of slower run-time due to the longer input sequence. @@ -231,6 +235,7 @@ class BaseTTSConfig(BaseTrainingConfig): compute_linear_spec: bool = False precompute_num_workers: int = 0 use_noise_augment: bool = False + start_by_longest: bool = False # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 36c948af..d306552d 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec (bool): If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. - sort_by_audio_len (bool): - If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`. - - min_seq_len (int): - Minimum sequnce length to be considered for training. Defaults to `0`. - - max_seq_len (int): - Maximum sequnce length to be considered for training. Defaults to `500000`. - r (int): Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. 
@@ -123,6 +114,7 @@ class VitsConfig(BaseTTSConfig): feat_loss_alpha: float = 1.0 mel_loss_alpha: float = 45.0 dur_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 speaker_encoder_loss_alpha: float = 1.0 # data loader params @@ -130,9 +122,6 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec: bool = True # overrides - sort_by_audio_len: bool = True - min_seq_len: int = 0 - max_seq_len: int = 500000 r: int = 1 # DO NOT CHANGE add_blank: bool = True diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index a98afc95..a1bb23c3 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -56,6 +56,7 @@ class TTSDataset(Dataset): d_vector_mapping: Dict = None, language_id_mapping: Dict = None, use_noise_augment: bool = False, + start_by_longest: bool = False, verbose: bool = False, ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. @@ -109,6 +110,8 @@ class TTSDataset(Dataset): use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + verbose (bool): Print diagnostic information. Defaults to false. 
""" super().__init__() @@ -130,6 +133,7 @@ class TTSDataset(Dataset): self.d_vector_mapping = d_vector_mapping self.language_id_mapping = language_id_mapping self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest self.verbose = verbose self.rescue_item_idx = 1 @@ -315,6 +319,12 @@ class TTSDataset(Dataset): samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx) sorted_idxs = self.sort_by_length(audio_lengths) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs) if len(samples) == 0: diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 9a6a56df..7cdfa915 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -290,6 +290,7 @@ class BaseTTS(BaseModel): speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, language_id_mapping=language_id_mapping, ) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 3ecd42e1..f96154bc 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -63,6 +63,7 @@ class TestTTSDataset(unittest.TestCase): max_text_len=c.max_text_len, min_audio_len=c.min_audio_len, max_audio_len=c.max_audio_len, + start_by_longest=start_by_longest ) dataloader = DataLoader( dataset, @@ -142,6 +143,23 @@ class TestTTSDataset(unittest.TestCase): self.assertGreaterEqual(avg_length, last_length) self.assertTrue(is_items_reordered) + def test_start_by_longest(self): + """Test start_by_longest option. + + The first item of the first batch must be longer than all the other items. 
+ """ + if ok_ljspeech: + dataloader, _ = self._create_dataloader(2, c.r, 0, True) + dataloader.dataset.preprocess_samples() + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + mel_lengths = data["mel_lengths"] + if i == 0: + max_len = mel_lengths[0] + print(mel_lengths) + self.assertTrue(all(max_len >= mel_lengths)) + def test_padding_and_spectrograms(self): def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths): self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding