From 29b17c080899075ad3851c7f0a906d5b4ca272ec Mon Sep 17 00:00:00 2001 From: erogol Date: Sun, 13 Dec 2020 02:36:23 +0100 Subject: [PATCH] bug fix for gradual training --- TTS/bin/train_tacotron.py | 59 ++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index f75b44af..8288023c 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -39,33 +39,34 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, use_cuda, num_gpus = setup_torch_training_env(True, False) -def setup_loader(ap, r, is_val=False, verbose=False): +def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): if is_val and not c.run_eval: loader = None else: - dataset = MyDataset( - r, - c.text_cleaner, - compute_linear_spec=c.model.lower() == 'tacotron', - meta_data=meta_data_eval if is_val else meta_data_train, - ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, - min_seq_len=c.min_seq_len, - max_seq_len=c.max_seq_len, - phoneme_cache_path=c.phoneme_cache_path, - use_phonemes=c.use_phonemes, - phoneme_language=c.phoneme_language, - enable_eos_bos=c.enable_eos_bos_chars, - verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + if dataset is None: + dataset = MyDataset( + r, + c.text_cleaner, + compute_linear_spec=c.model.lower() == 'tacotron', + meta_data=meta_data_eval if is_val else meta_data_train, + ap=ap, + tp=c.characters if 'characters' in c.keys() else None, + add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * + c.batch_size, + min_seq_len=c.min_seq_len, + max_seq_len=c.max_seq_len, + phoneme_cache_path=c.phoneme_cache_path, + use_phonemes=c.use_phonemes, + phoneme_language=c.phoneme_language, + enable_eos_bos=c.enable_eos_bos_chars, + verbose=verbose, + speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) - if c.use_phonemes and c.compute_input_seq_cache: - # precompute phonemes to have a better estimate of sequence lengths. - dataset.compute_input_seq(c.num_loader_workers) - dataset.sort_items() + if c.use_phonemes and c.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(c.num_loader_workers) + dataset.sort_items() sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( @@ -604,15 +605,23 @@ def main(args): # pylint: disable=redefined-outer-name if c.bidirectional_decoder: model.decoder_backward.set_r(r) train_loader.dataset.outputs_per_step = r - train_loader.batch_size = c.batch_size eval_loader.dataset.outputs_per_step = r - eval_loader.batch_size = c.batch_size + train_loader = setup_loader(ap, + model.decoder.r, + is_val=False, + dataset=train_loader.dataset) + eval_loader = setup_loader(ap, + model.decoder.r, + is_val=True, + dataset=eval_loader.dataset) print("\n > Number of output frames:", model.decoder.r) + # train one epoch train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st) + # eval one epoch eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)