diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 88d60d7d..9637fe75 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -97,7 +97,7 @@ Example run: enable_eos_bos=C.enable_eos_bos_chars, ) - dataset.sort_items() + dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False)) loader = DataLoader( dataset, batch_size=args.batch_size, diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index debe5933..0eee3083 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False): if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. dataset.compute_input_seq(c.num_loader_workers) - dataset.sort_items() + dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False)) loader = DataLoader( dataset, diff --git a/TTS/trainer.py b/TTS/trainer.py index 6a5c925a..b22595b3 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import importlib -import logging import multiprocessing import os import platform diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 3bf0b13a..58fc66ee 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -120,8 +120,9 @@ class VitsConfig(BaseTTSConfig): compute_linear_spec: bool = True # overrides - min_seq_len: int = 32 - max_seq_len: int = 1000 + sort_by_audio_len: bool = True + min_seq_len: int = 0 + max_seq_len: int = 500000 r: int = 1 # DO NOT CHANGE add_blank: bool = True diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 89326c9c..5d38243e 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -69,7 +69,7 @@ class TTSDataset(Dataset): batch. Set 0 to disable. Defaults to 0. 
min_seq_len (int): Minimum input sequence length to be processed - by the loader. Filter out input sequences that are shorter than this. Some models have a + by `sort_and_filter_items`. Filter out input sequences that are shorter than this. Some models have a minimum input length due to its architecture. Defaults to 0. max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this. @@ -302,10 +302,23 @@ class TTSDataset(Dataset): for idx, p in enumerate(phonemes): self.items[idx][0] = p - def sort_items(self): - r"""Sort instances based on text length in ascending order""" - lengths = np.array([len(ins[0]) for ins in self.items]) + def sort_and_filter_items(self, by_audio_len=False): + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length + range. + Args: + by_audio_len (bool): if True, sort by audio length else by text length. + """ + # compute the target sequence length + if by_audio_len: + lengths = [] + for item in self.items: + lengths.append(os.path.getsize(item[1])) + lengths = np.array(lengths) + else: + lengths = np.array([len(ins[0]) for ins in self.items]) + + # sort items based on the sequence length in ascending order idxs = np.argsort(lengths) new_items = [] ignored = [] @@ -315,7 +328,10 @@ class TTSDataset(Dataset): ignored.append(idx) else: new_items.append(self.items[idx]) + # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. 
if self.batch_group_size > 0: for i in range(len(new_items) // self.batch_group_size): offset = i * self.batch_group_size @@ -325,6 +341,7 @@ class TTSDataset(Dataset): new_items[offset:end_offset] = temp_items self.items = new_items + # logging if self.verbose: print(" | > Max length sequence: {}".format(np.max(lengths))) print(" | > Min length sequence: {}".format(np.min(lengths))) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 7f40b3eb..922761cb 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -243,7 +243,7 @@ class BaseTTS(BaseModel): dist.barrier() # sort input sequences from short to long - dataset.sort_items() + dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index 45e9d429..8bff0ab7 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -43,7 +43,7 @@ config = VitsConfig( print_step=25, print_eval=True, mixed_precision=True, - max_seq_len=5000, + max_seq_len=500000, output_path=output_path, datasets=[dataset_config], ) diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 10067094..717b2e0f 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase): avg_length = mel_lengths.numpy().mean() assert avg_length >= last_length - dataloader.dataset.sort_items() + dataloader.dataset.sort_and_filter_items() is_items_reordered = False for idx, item in enumerate(dataloader.dataset.items): if item != frames[idx]: