mirror of https://github.com/coqui-ai/TTS.git
Updates BaseTTS and configs
This commit is contained in:
parent
176b712c1a
commit
4cd690e4c1
|
@ -146,11 +146,19 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
sort_by_audio_len (bool):
|
sort_by_audio_len (bool):
|
||||||
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
|
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
|
||||||
|
|
||||||
min_seq_len (int):
|
min_text_len (int):
|
||||||
Minimum sequence length to be used at training.
|
Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
|
||||||
|
|
||||||
max_seq_len (int):
|
max_text_len (int):
|
||||||
Maximum sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
|
||||||
|
|
||||||
|
min_audio_len (int):
|
||||||
|
Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
|
||||||
|
|
||||||
|
max_audio_len (int):
|
||||||
|
Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
|
||||||
|
dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
|
||||||
|
OOM error in training. Defaults to float("inf").
|
||||||
|
|
||||||
compute_f0 (int):
|
compute_f0 (int):
|
||||||
(Not in use yet).
|
(Not in use yet).
|
||||||
|
@ -211,8 +219,10 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
loss_masking: bool = None
|
loss_masking: bool = None
|
||||||
# dataloading
|
# dataloading
|
||||||
sort_by_audio_len: bool = False
|
sort_by_audio_len: bool = False
|
||||||
min_seq_len: int = 1
|
min_audio_len: int = 1
|
||||||
max_seq_len: int = float("inf")
|
max_audio_len: int = float("inf")
|
||||||
|
min_text_len: int = 1
|
||||||
|
max_text_len: int = float("inf")
|
||||||
compute_f0: bool = False
|
compute_f0: bool = False
|
||||||
compute_linear_spec: bool = False
|
compute_linear_spec: bool = False
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
|
|
|
@ -168,8 +168,8 @@ class BaseTTS(BaseModel):
|
||||||
Dict: [description]
|
Dict: [description]
|
||||||
"""
|
"""
|
||||||
# setup input batch
|
# setup input batch
|
||||||
text_input = batch["text"]
|
text_input = batch["token_id"]
|
||||||
text_lengths = batch["text_lengths"]
|
text_lengths = batch["token_id_lengths"]
|
||||||
speaker_names = batch["speaker_names"]
|
speaker_names = batch["speaker_names"]
|
||||||
linear_input = batch["linear"]
|
linear_input = batch["linear"]
|
||||||
mel_input = batch["mel"]
|
mel_input = batch["mel"]
|
||||||
|
@ -261,10 +261,6 @@ class BaseTTS(BaseModel):
|
||||||
d_vector_mapping = None
|
d_vector_mapping = None
|
||||||
|
|
||||||
# setup custom symbols if needed
|
# setup custom symbols if needed
|
||||||
custom_symbols = None
|
|
||||||
if hasattr(self, "make_symbols"):
|
|
||||||
custom_symbols = self.make_symbols(self.config)
|
|
||||||
|
|
||||||
if hasattr(self, "language_manager"):
|
if hasattr(self, "language_manager"):
|
||||||
language_id_mapping = (
|
language_id_mapping = (
|
||||||
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
|
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
|
||||||
|
@ -282,8 +278,10 @@ class BaseTTS(BaseModel):
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
return_wav=config.return_wav if "return_wav" in config else False,
|
return_wav=config.return_wav if "return_wav" in config else False,
|
||||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||||
min_seq_len=config.min_seq_len,
|
min_text_len=config.min_text_len,
|
||||||
max_seq_len=config.max_seq_len,
|
max_text_len=config.max_text_len,
|
||||||
|
min_audio_len=config.min_audio_len,
|
||||||
|
max_audio_len=config.max_audio_len,
|
||||||
phoneme_cache_path=config.phoneme_cache_path,
|
phoneme_cache_path=config.phoneme_cache_path,
|
||||||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
@ -292,45 +290,12 @@ class BaseTTS(BaseModel):
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# pre-compute phonemes
|
# wait all the DDP process to be ready
|
||||||
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
|
|
||||||
if hasattr(self, "eval_data_items") and is_eval:
|
|
||||||
dataset.items = self.eval_data_items
|
|
||||||
elif hasattr(self, "train_data_items") and not is_eval:
|
|
||||||
dataset.items = self.train_data_items
|
|
||||||
else:
|
|
||||||
# precompute phonemes for precise estimate of sequence lengths.
|
|
||||||
# otherwise `dataset.sort_items()` uses raw text lengths
|
|
||||||
dataset.compute_input_seq(config.num_loader_workers)
|
|
||||||
|
|
||||||
# TODO: find a more efficient solution
|
|
||||||
# cheap hack - store items in the model state to avoid recomputing when reinit the dataset
|
|
||||||
if is_eval:
|
|
||||||
self.eval_data_items = dataset.items
|
|
||||||
else:
|
|
||||||
self.train_data_items = dataset.items
|
|
||||||
|
|
||||||
# halt DDP processes for the main process to finish computing the phoneme cache
|
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
# sort input sequences from short to long
|
# sort input sequences from short to long
|
||||||
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
dataset.preprocess_samples()
|
||||||
|
|
||||||
# compute pitch frames and write to files.
|
|
||||||
if config.compute_f0 and rank in [None, 0]:
|
|
||||||
if not os.path.exists(config.f0_cache_path):
|
|
||||||
dataset.pitch_extractor.compute_pitch(
|
|
||||||
self.ap, config.get("f0_cache_path", None), config.num_loader_workers
|
|
||||||
)
|
|
||||||
|
|
||||||
# halt DDP processes for the main process to finish computing the F0 cache
|
|
||||||
if num_gpus > 1:
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
# load pitch stats computed above by all the workers
|
|
||||||
if config.compute_f0:
|
|
||||||
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
|
||||||
|
|
||||||
# sampler for DDP
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
|
Loading…
Reference in New Issue