mirror of https://github.com/coqui-ai/TTS.git
bug fix for gradual training
commit 29b17c0808
parent 5c50e104d6
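For context: gradual training starts the Tacotron decoder with a coarse reduction factor r (frames predicted per decoder step) and lowers it at scheduled global steps, usually together with a new batch size. A minimal sketch of that scheduling idea, assuming a `[start_step, r, batch_size]` schedule format; the values and the `current_schedule` helper are illustrative, not this repo's actual scheduler:

```python
# Illustrative gradual-training schedule: [start_step, r, batch_size].
gradual_training = [
    [0, 7, 64],        # start coarse: 7 frames per decoder step
    [10_000, 5, 64],
    [50_000, 3, 32],   # finer r typically needs a smaller batch
    [130_000, 2, 32],
]

def current_schedule(global_step, schedule):
    """Return the (r, batch_size) pair active at global_step."""
    r, batch_size = schedule[0][1], schedule[0][2]
    for start_step, new_r, new_bs in schedule:
        if global_step >= start_step:
            r, batch_size = new_r, new_bs
    return r, batch_size

print(current_schedule(60_000, gradual_training))  # -> (3, 32)
```

Whenever the active pair changes, the training loop has to push the new r and batch size into the data pipeline; that hand-off is the code path this commit repairs.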
@@ -39,33 +39,34 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False):
+def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
     if is_val and not c.run_eval:
         loader = None
     else:
-        dataset = MyDataset(
-            r,
-            c.text_cleaner,
-            compute_linear_spec=c.model.lower() == 'tacotron',
-            meta_data=meta_data_eval if is_val else meta_data_train,
-            ap=ap,
-            tp=c.characters if 'characters' in c.keys() else None,
-            add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
-            batch_group_size=0 if is_val else c.batch_group_size *
-            c.batch_size,
-            min_seq_len=c.min_seq_len,
-            max_seq_len=c.max_seq_len,
-            phoneme_cache_path=c.phoneme_cache_path,
-            use_phonemes=c.use_phonemes,
-            phoneme_language=c.phoneme_language,
-            enable_eos_bos=c.enable_eos_bos_chars,
-            verbose=verbose,
-            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
-
-        if c.use_phonemes and c.compute_input_seq_cache:
-            # precompute phonemes to have a better estimate of sequence lengths.
-            dataset.compute_input_seq(c.num_loader_workers)
-        dataset.sort_items()
+        if dataset is None:
+            dataset = MyDataset(
+                r,
+                c.text_cleaner,
+                compute_linear_spec=c.model.lower() == 'tacotron',
+                meta_data=meta_data_eval if is_val else meta_data_train,
+                ap=ap,
+                tp=c.characters if 'characters' in c.keys() else None,
+                add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
+                batch_group_size=0 if is_val else c.batch_group_size *
+                c.batch_size,
+                min_seq_len=c.min_seq_len,
+                max_seq_len=c.max_seq_len,
+                phoneme_cache_path=c.phoneme_cache_path,
+                use_phonemes=c.use_phonemes,
+                phoneme_language=c.phoneme_language,
+                enable_eos_bos=c.enable_eos_bos_chars,
+                verbose=verbose,
+                speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+
+            if c.use_phonemes and c.compute_input_seq_cache:
+                # precompute phonemes to have a better estimate of sequence lengths.
+                dataset.compute_input_seq(c.num_loader_workers)
+            dataset.sort_items()
 
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
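The new `dataset` argument lets callers hand back an already-built MyDataset, so construction, the optional phoneme precompute, and item sorting run only once; later calls merely wrap the cached dataset in a fresh DataLoader. A usage sketch of that pattern, mirroring the call sites in the second hunk:

```python
# First call: dataset is None, so MyDataset is built, phoneme inputs
# are precomputed, and items are sorted.
train_loader = setup_loader(ap, model.decoder.r, is_val=False)

# After the schedule changes r: reuse the cached dataset and only
# rebuild the DataLoader, so the new batch size actually applies.
train_loader.dataset.outputs_per_step = r
train_loader = setup_loader(ap,
                            model.decoder.r,
                            is_val=False,
                            dataset=train_loader.dataset)
```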
@@ -604,15 +605,23 @@ def main(args):  # pylint: disable=redefined-outer-name
             if c.bidirectional_decoder:
                 model.decoder_backward.set_r(r)
             train_loader.dataset.outputs_per_step = r
-            train_loader.batch_size = c.batch_size
             eval_loader.dataset.outputs_per_step = r
-            eval_loader.batch_size = c.batch_size
+            train_loader = setup_loader(ap,
+                                        model.decoder.r,
+                                        is_val=False,
+                                        dataset=train_loader.dataset)
+            eval_loader = setup_loader(ap,
+                                       model.decoder.r,
+                                       is_val=True,
+                                       dataset=eval_loader.dataset)
             print("\n > Number of output frames:", model.decoder.r)
+        # train one epoch
         train_avg_loss_dict, global_step = train(train_loader, model,
                                                  criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
                                                  global_step, epoch, scaler,
                                                  scaler_st)
+        # eval one epoch
         eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
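Why the removed assignments were a bug: PyTorch's DataLoader forbids setting `batch_size` (and a few other attributes) after construction, so `train_loader.batch_size = c.batch_size` raises a ValueError instead of resizing batches. A self-contained demonstration, with a toy TensorDataset standing in for MyDataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for MyDataset.
dataset = TensorDataset(torch.arange(12, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=4)

try:
    # What the removed lines effectively did; PyTorch rejects it with
    # "batch_size attribute should not be set after DataLoader is
    # initialized".
    loader.batch_size = 2
except ValueError as err:
    print("refused:", err)

# What the fix does instead: rebuild the loader around the same
# dataset object, so the new batch size really takes effect.
loader = DataLoader(loader.dataset, batch_size=2)
print(next(iter(loader))[0].shape)  # torch.Size([2])
```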