mirror of https://github.com/coqui-ai/TTS.git
Implement `start_by_longest` option for TTSDataset

parent c4c471d61d
commit ef63c99524

@@ -172,6 +172,10 @@ class BaseTTSConfig(BaseTrainingConfig):
         use_noise_augment (bool):
             Augment the input audio with random noise.
 
+        start_by_longest (bool):
+            If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
+            Defaults to False.
+
         add_blank (bool):
             Add blank characters between each other two characters. It improves performance for some models at expense
             of slower run-time due to the longer input sequence.

@@ -231,6 +235,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     compute_linear_spec: bool = False
     precompute_num_workers: int = 0
     use_noise_augment: bool = False
+    start_by_longest: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer

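For quick reference, a minimal sketch of enabling the new option from a config object; the import path is assumed from the class name above and is not part of this commit:

    from TTS.tts.configs.shared_configs import BaseTTSConfig  # assumed module path

    # Start training with the heaviest batch so out-of-memory errors surface immediately.
    config = BaseTTSConfig(start_by_longest=True)
    print(config.start_by_longest)  # True

Leaving the option at its default `False` keeps the previous loader behaviour.
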
@@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
         compute_linear_spec (bool):
             If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
 
-        sort_by_audio_len (bool):
-            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
-
-        min_seq_len (int):
-            Minimum sequnce length to be considered for training. Defaults to `0`.
-
-        max_seq_len (int):
-            Maximum sequnce length to be considered for training. Defaults to `500000`.
-
         r (int):
             Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

@@ -123,6 +114,7 @@ class VitsConfig(BaseTTSConfig):
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
+    aligner_loss_alpha = 1.0
     speaker_encoder_loss_alpha: float = 1.0
 
     # data loader params

@@ -130,9 +122,6 @@ class VitsConfig(BaseTTSConfig):
     compute_linear_spec: bool = True
 
     # overrides
-    sort_by_audio_len: bool = True
-    min_seq_len: int = 0
-    max_seq_len: int = 500000
     r: int = 1 # DO NOT CHANGE
     add_blank: bool = True
 

@@ -56,6 +56,7 @@ class TTSDataset(Dataset):
         d_vector_mapping: Dict = None,
         language_id_mapping: Dict = None,
         use_noise_augment: bool = False,
+        start_by_longest: bool = False,
         verbose: bool = False,
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

@@ -109,6 +110,8 @@ class TTSDataset(Dataset):
 
             use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
 
+            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
+
             verbose (bool): Print diagnostic information. Defaults to false.
         """
         super().__init__()

@@ -130,6 +133,7 @@ class TTSDataset(Dataset):
         self.d_vector_mapping = d_vector_mapping
         self.language_id_mapping = language_id_mapping
         self.use_noise_augment = use_noise_augment
+        self.start_by_longest = start_by_longest
 
         self.verbose = verbose
         self.rescue_item_idx = 1

@@ -315,6 +319,12 @@ class TTSDataset(Dataset):
         samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
 
         sorted_idxs = self.sort_by_length(audio_lengths)
+
+        if self.start_by_longest:
+            longest_idxs = sorted_idxs[-1]
+            sorted_idxs[-1] = sorted_idxs[0]
+            sorted_idxs[0] = longest_idxs
+
         samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
 
         if len(samples) == 0:

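To illustrate what the new block does, here is a standalone toy version of the index swap, using hypothetical lengths and plain Python in place of the class internals:

    # Hypothetical audio lengths (in frames) for five samples.
    audio_lengths = [120, 300, 80, 950, 45]

    # Ascending sort by length, mirroring sort_by_length(): [4, 2, 0, 1, 3]
    sorted_idxs = sorted(range(len(audio_lengths)), key=lambda i: audio_lengths[i])

    # start_by_longest: move the longest sample to the front of the ordering.
    longest_idx = sorted_idxs[-1]
    sorted_idxs[-1] = sorted_idxs[0]
    sorted_idxs[0] = longest_idx

    print(sorted_idxs)  # [3, 2, 0, 1, 4] -> the 950-frame sample is loaded first

Only the first and last positions change; the rest of the ascending order is preserved.
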
@@ -290,6 +290,7 @@ class BaseTTS(BaseModel):
                 speaker_id_mapping=speaker_id_mapping,
                 d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
                 tokenizer=self.tokenizer,
+                start_by_longest=config.start_by_longest,
                 language_id_mapping=language_id_mapping,
             )
 

@@ -63,6 +63,7 @@ class TestTTSDataset(unittest.TestCase):
             max_text_len=c.max_text_len,
             min_audio_len=c.min_audio_len,
             max_audio_len=c.max_audio_len,
+            start_by_longest=start_by_longest
         )
         dataloader = DataLoader(
             dataset,

@@ -142,6 +143,23 @@ class TestTTSDataset(unittest.TestCase):
                     self.assertGreaterEqual(avg_length, last_length)
                 self.assertTrue(is_items_reordered)
 
+    def test_start_by_longest(self):
+        """Test start_by_longest option.
+
+        The first item of the first batch must be longer than all the other items.
+        """
+        if ok_ljspeech:
+            dataloader, _ = self._create_dataloader(2, c.r, 0, True)
+            dataloader.dataset.preprocess_samples()
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                mel_lengths = data["mel_lengths"]
+                if i == 0:
+                    max_len = mel_lengths[0]
+                print(mel_lengths)
+                self.assertTrue(all(max_len >= mel_lengths))
+
     def test_padding_and_spectrograms(self):
         def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
             self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding
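
For context, the core check of the new test boils down to the snippet below; the tensor values are hypothetical, while the real test reads `mel_lengths` from the first batch of its data loader:

    import torch

    # Hypothetical mel lengths of the first batch when start_by_longest=True.
    mel_lengths = torch.tensor([834, 412, 388, 301])
    max_len = mel_lengths[0]  # the leading item should be the longest in the batch
    assert bool((max_len >= mel_lengths).all())  # equivalent to all(max_len >= mel_lengths)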