bug fix for gradual training

erogol 2020-12-13 02:36:23 +01:00
parent 5c50e104d6
commit 29b17c0808
1 changed file with 34 additions and 25 deletions


@@ -39,33 +39,34 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False):
+def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
     if is_val and not c.run_eval:
         loader = None
     else:
-        dataset = MyDataset(
-            r,
-            c.text_cleaner,
-            compute_linear_spec=c.model.lower() == 'tacotron',
-            meta_data=meta_data_eval if is_val else meta_data_train,
-            ap=ap,
-            tp=c.characters if 'characters' in c.keys() else None,
-            add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
-            batch_group_size=0 if is_val else c.batch_group_size *
-            c.batch_size,
-            min_seq_len=c.min_seq_len,
-            max_seq_len=c.max_seq_len,
-            phoneme_cache_path=c.phoneme_cache_path,
-            use_phonemes=c.use_phonemes,
-            phoneme_language=c.phoneme_language,
-            enable_eos_bos=c.enable_eos_bos_chars,
-            verbose=verbose,
-            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+        if dataset is None:
+            dataset = MyDataset(
+                r,
+                c.text_cleaner,
+                compute_linear_spec=c.model.lower() == 'tacotron',
+                meta_data=meta_data_eval if is_val else meta_data_train,
+                ap=ap,
+                tp=c.characters if 'characters' in c.keys() else None,
+                add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
+                batch_group_size=0 if is_val else c.batch_group_size *
+                c.batch_size,
+                min_seq_len=c.min_seq_len,
+                max_seq_len=c.max_seq_len,
+                phoneme_cache_path=c.phoneme_cache_path,
+                use_phonemes=c.use_phonemes,
+                phoneme_language=c.phoneme_language,
+                enable_eos_bos=c.enable_eos_bos_chars,
+                verbose=verbose,
+                speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
         if c.use_phonemes and c.compute_input_seq_cache:
             # precompute phonemes to have a better estimate of sequence lengths.
             dataset.compute_input_seq(c.num_loader_workers)
         dataset.sort_items()
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
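Note: the new optional dataset argument lets a later call rebuild only the DataLoader around an already-constructed MyDataset, so cached metadata and precomputed phoneme sequences are kept while the sampler and batching are recreated for the new r. Below is a minimal sketch of the same reuse-or-create pattern; ToyDataset and setup_toy_loader are illustrative stand-ins, not the repo's API:

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # Stand-in for MyDataset; assume construction is expensive.
    def __init__(self):
        self.items = list(range(100))
        self.outputs_per_step = 7

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

def setup_toy_loader(batch_size, dataset=None):
    # Mirrors the commit's pattern: build the dataset only when the
    # caller did not pass one in, then always build a fresh DataLoader.
    if dataset is None:
        dataset = ToyDataset()
    return DataLoader(dataset, batch_size=batch_size)

loader = setup_toy_loader(batch_size=64)
# When the schedule changes r and batch size, rebuild the loader around
# the same dataset object instead of mutating the live loader:
loader.dataset.outputs_per_step = 5
loader = setup_toy_loader(batch_size=32, dataset=loader.dataset)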
@@ -604,15 +605,23 @@ def main(args): # pylint: disable=redefined-outer-name
             if c.bidirectional_decoder:
                 model.decoder_backward.set_r(r)
             train_loader.dataset.outputs_per_step = r
-            train_loader.batch_size = c.batch_size
             eval_loader.dataset.outputs_per_step = r
-            eval_loader.batch_size = c.batch_size
+            train_loader = setup_loader(ap,
+                                        model.decoder.r,
+                                        is_val=False,
+                                        dataset=train_loader.dataset)
+            eval_loader = setup_loader(ap,
+                                       model.decoder.r,
+                                       is_val=True,
+                                       dataset=eval_loader.dataset)
             print("\n > Number of output frames:", model.decoder.r)
+        # train one epoch
         train_avg_loss_dict, global_step = train(train_loader, model,
                                                  criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
                                                  global_step, epoch, scaler,
                                                  scaler_st)
+        # eval one epoch
         eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
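The removed assignments look like the bug being fixed: PyTorch's DataLoader refuses to rebind batch_size (and related attributes) after construction, its __setattr__ raising ValueError, so the gradual-training step would likely fail as soon as the schedule changed r; and even if the assignment had taken effect, the length-sorted batch grouping computed at dataset setup would have stayed stale. Rebuilding both loaders through setup_loader with the existing datasets avoids the in-place mutation entirely. A small reproduction of the rejected assignment (behavior as of PyTorch 1.x; the tensor stands in for a real dataset):

import torch
from torch.utils.data import DataLoader

loader = DataLoader(torch.arange(10), batch_size=4)
try:
    loader.batch_size = 2  # what the old code effectively did
except ValueError as err:
    # DataLoader blocks rebinding batch_size once it is initialized.
    print(err)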