mirror of https://github.com/coqui-ai/TTS.git
Add option to sort input sequnce by audio len
This commit is contained in:
parent
695a6439d3
commit
f186856e5d
|
@ -97,7 +97,7 @@ Example run:
|
||||||
enable_eos_bos=C.enable_eos_bos_chars,
|
enable_eos_bos=C.enable_eos_bos_chars,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset.sort_items()
|
dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
|
|
|
@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False):
|
||||||
if c.use_phonemes and c.compute_input_seq_cache:
|
if c.use_phonemes and c.compute_input_seq_cache:
|
||||||
# precompute phonemes to have a better estimate of sequence lengths.
|
# precompute phonemes to have a better estimate of sequence lengths.
|
||||||
dataset.compute_input_seq(c.num_loader_workers)
|
dataset.compute_input_seq(c.num_loader_workers)
|
||||||
dataset.sort_items()
|
dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
|
||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
|
|
@ -120,8 +120,9 @@ class VitsConfig(BaseTTSConfig):
|
||||||
compute_linear_spec: bool = True
|
compute_linear_spec: bool = True
|
||||||
|
|
||||||
# overrides
|
# overrides
|
||||||
min_seq_len: int = 32
|
sort_by_audio_len: bool = True
|
||||||
max_seq_len: int = 1000
|
min_seq_len: int = 0
|
||||||
|
max_seq_len: int = 500000
|
||||||
r: int = 1 # DO NOT CHANGE
|
r: int = 1 # DO NOT CHANGE
|
||||||
add_blank: bool = True
|
add_blank: bool = True
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ class TTSDataset(Dataset):
|
||||||
batch. Set 0 to disable. Defaults to 0.
|
batch. Set 0 to disable. Defaults to 0.
|
||||||
|
|
||||||
min_seq_len (int): Minimum input sequence length to be processed
|
min_seq_len (int): Minimum input sequence length to be processed
|
||||||
by the loader. Filter out input sequences that are shorter than this. Some models have a
|
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
|
||||||
minimum input length due to its architecture. Defaults to 0.
|
minimum input length due to its architecture. Defaults to 0.
|
||||||
|
|
||||||
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
||||||
|
@ -302,10 +302,23 @@ class TTSDataset(Dataset):
|
||||||
for idx, p in enumerate(phonemes):
|
for idx, p in enumerate(phonemes):
|
||||||
self.items[idx][0] = p
|
self.items[idx][0] = p
|
||||||
|
|
||||||
def sort_items(self):
|
def sort_and_filter_items(self, by_audio_len=False):
|
||||||
r"""Sort instances based on text length in ascending order"""
|
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
||||||
lengths = np.array([len(ins[0]) for ins in self.items])
|
range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
by_audio_len (bool): if True, sort by audio length else by text length.
|
||||||
|
"""
|
||||||
|
# compute the target sequence length
|
||||||
|
if by_audio_len:
|
||||||
|
lengths = []
|
||||||
|
for item in self.items:
|
||||||
|
lengths.append(os.path.getsize(item[1]))
|
||||||
|
lengths = np.array(lengths)
|
||||||
|
else:
|
||||||
|
lengths = np.array([len(ins[0]) for ins in self.items])
|
||||||
|
|
||||||
|
# sort items based on the sequence length in ascending order
|
||||||
idxs = np.argsort(lengths)
|
idxs = np.argsort(lengths)
|
||||||
new_items = []
|
new_items = []
|
||||||
ignored = []
|
ignored = []
|
||||||
|
@ -315,7 +328,10 @@ class TTSDataset(Dataset):
|
||||||
ignored.append(idx)
|
ignored.append(idx)
|
||||||
else:
|
else:
|
||||||
new_items.append(self.items[idx])
|
new_items.append(self.items[idx])
|
||||||
|
|
||||||
# shuffle batch groups
|
# shuffle batch groups
|
||||||
|
# create batches with similar length items
|
||||||
|
# the larger the `batch_group_size`, the higher the length variety in a batch.
|
||||||
if self.batch_group_size > 0:
|
if self.batch_group_size > 0:
|
||||||
for i in range(len(new_items) // self.batch_group_size):
|
for i in range(len(new_items) // self.batch_group_size):
|
||||||
offset = i * self.batch_group_size
|
offset = i * self.batch_group_size
|
||||||
|
@ -325,6 +341,7 @@ class TTSDataset(Dataset):
|
||||||
new_items[offset:end_offset] = temp_items
|
new_items[offset:end_offset] = temp_items
|
||||||
self.items = new_items
|
self.items = new_items
|
||||||
|
|
||||||
|
# logging
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(" | > Max length sequence: {}".format(np.max(lengths)))
|
print(" | > Max length sequence: {}".format(np.max(lengths)))
|
||||||
print(" | > Min length sequence: {}".format(np.min(lengths)))
|
print(" | > Min length sequence: {}".format(np.min(lengths)))
|
||||||
|
|
|
@ -243,7 +243,7 @@ class BaseTTS(BaseModel):
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
# sort input sequences from short to long
|
# sort input sequences from short to long
|
||||||
dataset.sort_items()
|
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
||||||
|
|
||||||
# sampler for DDP
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
|
|
@ -43,7 +43,7 @@ config = VitsConfig(
|
||||||
print_step=25,
|
print_step=25,
|
||||||
print_eval=True,
|
print_eval=True,
|
||||||
mixed_precision=True,
|
mixed_precision=True,
|
||||||
max_seq_len=5000,
|
max_seq_len=500000,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
)
|
)
|
||||||
|
|
|
@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
|
|
||||||
avg_length = mel_lengths.numpy().mean()
|
avg_length = mel_lengths.numpy().mean()
|
||||||
assert avg_length >= last_length
|
assert avg_length >= last_length
|
||||||
dataloader.dataset.sort_items()
|
dataloader.dataset.sort_and_filter_items()
|
||||||
is_items_reordered = False
|
is_items_reordered = False
|
||||||
for idx, item in enumerate(dataloader.dataset.items):
|
for idx, item in enumerate(dataloader.dataset.items):
|
||||||
if item != frames[idx]:
|
if item != frames[idx]:
|
||||||
|
|
Loading…
Reference in New Issue