mirror of https://github.com/coqui-ai/TTS.git

Updates BaseTTS and configs

parent 176b712c1a
commit 4cd690e4c1
@@ -146,11 +146,19 @@ class BaseTTSConfig(BaseTrainingConfig):
         sort_by_audio_len (bool):
             If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `False`.

-        min_seq_len (int):
-            Minimum sequence length to be used at training.
+        min_text_len (int):
+            Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.

-        max_seq_len (int):
-            Maximum sequence length to be used at training. Larger values result in more VRAM usage.
+        max_text_len (int):
+            Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
+
+        min_audio_len (int):
+            Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
+
+        max_audio_len (int):
+            Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
+            dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
+            OOM error in training. Defaults to float("inf").

         compute_f0 (int):
             (Not in use yet).
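The four new bounds replace the single min/max_seq_len pair with separate limits for audio and text. A minimal sketch of the filtering they describe, not the library's implementation; `samples`, `audio_len`, and `text_len` are hypothetical names:

    # Minimal sketch of the length filtering the four bounds describe.
    # `samples` is a hypothetical list of dicts with precomputed
    # "audio_len" and "text_len" entries.
    def filter_by_length(samples, min_audio_len=0, max_audio_len=float("inf"),
                         min_text_len=0, max_text_len=float("inf")):
        kept = []
        for sample in samples:
            if not min_audio_len <= sample["audio_len"] <= max_audio_len:
                continue  # audio out of range; the longest kept clip bounds VRAM usage
            if not min_text_len <= sample["text_len"] <= max_text_len:
                continue  # transcript out of range
            kept.append(sample)
        return kept

    samples = [{"audio_len": 22050, "text_len": 40}, {"audio_len": 2_000_000, "text_len": 40}]
    print(len(filter_by_length(samples, max_audio_len=500_000)))  # -> 1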
@@ -211,8 +219,10 @@ class BaseTTSConfig(BaseTrainingConfig):
     loss_masking: bool = None
     # dataloading
     sort_by_audio_len: bool = False
-    min_seq_len: int = 1
-    max_seq_len: int = float("inf")
+    min_audio_len: int = 1
+    max_audio_len: int = float("inf")
+    min_text_len: int = 1
+    max_text_len: int = float("inf")
     compute_f0: bool = False
     compute_linear_spec: bool = False
     use_noise_augment: bool = False
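Set on a config instance, the new fields read as below. A hedged example: the import path is an assumption, since it has moved between releases (`TTS.tts.configs.shared_configs` in later ones):

    # Assumed import path; it differs across TTS releases.
    from TTS.tts.configs.shared_configs import BaseTTSConfig

    config = BaseTTSConfig()
    config.min_audio_len = 1           # keep even the shortest clips
    config.max_audio_len = 10 * 22050  # cap clips near 10 s at 22.05 kHz to bound VRAM
    config.min_text_len = 1
    config.max_text_len = 200          # ignore transcripts longer than 200 characters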
@@ -168,8 +168,8 @@ class BaseTTS(BaseModel):
             Dict: [description]
         """
         # setup input batch
-        text_input = batch["text"]
-        text_lengths = batch["text_lengths"]
+        text_input = batch["token_id"]
+        text_lengths = batch["token_id_lengths"]
         speaker_names = batch["speaker_names"]
         linear_input = batch["linear"]
         mel_input = batch["mel"]
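The only change here is the pair of dictionary keys the batch formatter reads. A toy batch with the new names; the torch-based collation and the shapes are illustrative, not the repo's collate function:

    import torch

    batch = {
        "token_id": torch.zeros(2, 50, dtype=torch.long),  # padded token ids
        "token_id_lengths": torch.tensor([50, 42]),        # true lengths before padding
        "speaker_names": ["spk_a", "spk_b"],
        "linear": torch.zeros(2, 200, 513),                # linear spectrograms
        "mel": torch.zeros(2, 200, 80),                    # mel spectrograms
    }
    text_input = batch["token_id"]            # was batch["text"]
    text_lengths = batch["token_id_lengths"]  # was batch["text_lengths"]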
@@ -261,10 +261,6 @@ class BaseTTS(BaseModel):
                 d_vector_mapping = None

-            # setup custom symbols if needed
-            custom_symbols = None
-            if hasattr(self, "make_symbols"):
-                custom_symbols = self.make_symbols(self.config)
-
             if hasattr(self, "language_manager"):
                 language_id_mapping = (
                     self.language_manager.language_id_mapping if self.args.use_language_embedding else None
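With the custom-symbols block gone, only the optional language mapping remains: an attribute check plus a config gate. A schematic version of that pattern; the classes below are stand-ins, not repo code:

    # Stand-ins mimicking the attributes the diff touches.
    class LanguageManager:
        language_id_mapping = {"en": 0, "pt-br": 1}

    class Args:
        use_language_embedding = True

    class Model:
        args = Args()
        language_manager = LanguageManager()

    model = Model()
    language_id_mapping = None
    if hasattr(model, "language_manager"):
        # pass the mapping to the dataset only when language embeddings are enabled
        language_id_mapping = (
            model.language_manager.language_id_mapping if model.args.use_language_embedding else None
        )
    print(language_id_mapping)  # {'en': 0, 'pt-br': 1}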
@@ -282,8 +278,10 @@ class BaseTTS(BaseModel):
                 ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
-                min_seq_len=config.min_seq_len,
-                max_seq_len=config.max_seq_len,
+                min_text_len=config.min_text_len,
+                max_text_len=config.max_text_len,
+                min_audio_len=config.min_audio_len,
+                max_audio_len=config.max_audio_len,
                 phoneme_cache_path=config.phoneme_cache_path,
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
@@ -292,45 +290,12 @@ class BaseTTS(BaseModel):
                 tokenizer=self.tokenizer,
             )

-            # pre-compute phonemes
-            if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
-                if hasattr(self, "eval_data_items") and is_eval:
-                    dataset.items = self.eval_data_items
-                elif hasattr(self, "train_data_items") and not is_eval:
-                    dataset.items = self.train_data_items
-                else:
-                    # precompute phonemes for precise estimate of sequence lengths.
-                    # otherwise `dataset.sort_items()` uses raw text lengths
-                    dataset.compute_input_seq(config.num_loader_workers)
-
-                # TODO: find a more efficient solution
-                # cheap hack - store items in the model state to avoid recomputing when reinit the dataset
-                if is_eval:
-                    self.eval_data_items = dataset.items
-                else:
-                    self.train_data_items = dataset.items
-
-            # halt DDP processes for the main process to finish computing the phoneme cache
+            # wait all the DDP process to be ready
             if num_gpus > 1:
                 dist.barrier()

             # sort input sequences from short to long
-            dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
-
-            # compute pitch frames and write to files.
-            if config.compute_f0 and rank in [None, 0]:
-                if not os.path.exists(config.f0_cache_path):
-                    dataset.pitch_extractor.compute_pitch(
-                        self.ap, config.get("f0_cache_path", None), config.num_loader_workers
-                    )
-
-                # halt DDP processes for the main process to finish computing the F0 cache
-                if num_gpus > 1:
-                    dist.barrier()
-
-            # load pitch stats computed above by all the workers
-            if config.compute_f0:
-                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
+            dataset.preprocess_samples()

             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
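What survives of the removed block is the rank-0-then-barrier pattern: one process builds a shared cache while the other DDP workers block on `dist.barrier()`, and the sorting and filtering that `sort_and_filter_items` used to do appears to be folded into `dataset.preprocess_samples()`. A generic sketch of the barrier pattern, assuming an already-initialized torch.distributed process group; `compute_cache` is a placeholder:

    import os
    import torch.distributed as dist

    def build_shared_cache(cache_path, rank, num_gpus, compute_cache):
        """Let rank 0 (or a single-process run, rank=None) write the cache,
        then hold every DDP worker at the barrier until it is done."""
        if rank in [None, 0] and not os.path.exists(cache_path):
            compute_cache(cache_path)  # placeholder, e.g. F0 or phoneme extraction
        if num_gpus > 1:
            dist.barrier()  # all ranks sync here before reading the cache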