Updates BaseTTS and configs

This commit is contained in:
Eren Gölge 2021-11-30 15:52:01 +01:00
parent 176b712c1a
commit 4cd690e4c1
2 changed files with 24 additions and 49 deletions

View File

@ -146,11 +146,19 @@ class BaseTTSConfig(BaseTrainingConfig):
sort_by_audio_len (bool):
If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `False`.
min_seq_len (int):
Minimum sequence length to be used at training.
min_text_len (int):
Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
max_seq_len (int):
Maximum sequence length to be used at training. Larger values result in more VRAM usage.
max_text_len (int):
Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
min_audio_len (int):
Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
max_audio_len (int):
Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
OOM error in training. Defaults to float("inf").
compute_f0 (int):
(Not in use yet).
@ -211,8 +219,10 @@ class BaseTTSConfig(BaseTrainingConfig):
loss_masking: bool = None
# dataloading
sort_by_audio_len: bool = False
min_seq_len: int = 1
max_seq_len: int = float("inf")
min_audio_len: int = 1
max_audio_len: int = float("inf")
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False
compute_linear_spec: bool = False
use_noise_augment: bool = False

View File

@ -168,8 +168,8 @@ class BaseTTS(BaseModel):
Dict: [description]
"""
# setup input batch
text_input = batch["text"]
text_lengths = batch["text_lengths"]
text_input = batch["token_id"]
text_lengths = batch["token_id_lengths"]
speaker_names = batch["speaker_names"]
linear_input = batch["linear"]
mel_input = batch["mel"]
@ -261,10 +261,6 @@ class BaseTTS(BaseModel):
d_vector_mapping = None
# setup custom symbols if needed
custom_symbols = None
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
if hasattr(self, "language_manager"):
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
@ -282,8 +278,10 @@ class BaseTTS(BaseModel):
ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_seq_len=config.min_seq_len,
max_seq_len=config.max_seq_len,
min_text_len=config.min_text_len,
max_text_len=config.max_text_len,
min_audio_len=config.min_audio_len,
max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
@ -292,45 +290,12 @@ class BaseTTS(BaseModel):
tokenizer=self.tokenizer,
)
# pre-compute phonemes
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items
else:
# precompute phonemes for precise estimate of sequence lengths.
# otherwise `dataset.sort_items()` uses raw text lengths
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
# cheap hack - store items in the model state to avoid recomputing when reinit the dataset
if is_eval:
self.eval_data_items = dataset.items
else:
self.train_data_items = dataset.items
# halt DDP processes for the main process to finish computing the phoneme cache
# wait all the DDP process to be ready
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
# compute pitch frames and write to files.
if config.compute_f0 and rank in [None, 0]:
if not os.path.exists(config.f0_cache_path):
dataset.pitch_extractor.compute_pitch(
self.ap, config.get("f0_cache_path", None), config.num_loader_workers
)
# halt DDP processes for the main process to finish computing the F0 cache
if num_gpus > 1:
dist.barrier()
# load pitch stats computed above by all the workers
if config.compute_f0:
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
dataset.preprocess_samples()
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None