mirror of https://github.com/coqui-ai/TTS.git
bug fix for gradual training
parent 5c50e104d6
commit 29b17c0808
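Background for the fix: gradual training steps the Tacotron decoder through a schedule of reduction factors r (decoder output frames per step) and batch sizes as global_step grows, which is why the loaders in the diff below must track a changing r and batch size. A rough sketch of that schedule lookup; the row format [first_step, r, batch_size] mirrors the repo's "gradual_training" config option, but the values and the lookup() helper are illustrative, not the repo's implementation:

# Hypothetical sketch of a gradual-training schedule lookup.
schedule = [[0, 7, 64], [10000, 5, 64], [50000, 3, 32], [130000, 2, 32]]

def lookup(global_step, schedule):
    """Return the (r, batch_size) pair for the latest step threshold passed."""
    r, batch_size = schedule[0][1], schedule[0][2]
    for first_step, new_r, new_bs in schedule:
        if global_step >= first_step:
            r, batch_size = new_r, new_bs
    return r, batch_size

assert lookup(12000, schedule) == (5, 64)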
@@ -39,10 +39,11 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False):
+def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
     if is_val and not c.run_eval:
         loader = None
     else:
-        dataset = MyDataset(
-            r,
-            c.text_cleaner,
+        if dataset is None:
+            dataset = MyDataset(
+                r,
+                c.text_cleaner,
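The new dataset argument lets a caller hand an existing dataset back to setup_loader: MyDataset is only constructed when none is supplied, so a loader can be rebuilt mid-training without re-reading the dataset metadata. A minimal sketch of the same reuse pattern in plain PyTorch, with TensorDataset standing in for the repo's MyDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100).unsqueeze(1))
loader = DataLoader(dataset, batch_size=32)

# Later, when training parameters change, rebuild the loader but keep
# the dataset object (and everything it has already loaded) alive:
loader = DataLoader(loader.dataset, batch_size=16)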
@@ -604,15 +605,23 @@ def main(args):  # pylint: disable=redefined-outer-name
             if c.bidirectional_decoder:
                 model.decoder_backward.set_r(r)
             train_loader.dataset.outputs_per_step = r
-            train_loader.batch_size = c.batch_size
             eval_loader.dataset.outputs_per_step = r
-            eval_loader.batch_size = c.batch_size
+            train_loader = setup_loader(ap,
+                                        model.decoder.r,
+                                        is_val=False,
+                                        dataset=train_loader.dataset)
+            eval_loader = setup_loader(ap,
+                                       model.decoder.r,
+                                       is_val=True,
+                                       dataset=eval_loader.dataset)
             print("\n > Number of output frames:", model.decoder.r)
         # train one epoch
         train_avg_loss_dict, global_step = train(train_loader, model,
                                                  criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
                                                  global_step, epoch, scaler,
                                                  scaler_st)
         # eval one epoch
         eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
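Why the batch_size assignments had to go: PyTorch's DataLoader freezes batch_size after construction (its __setattr__ raises a ValueError), so mutating the live loaders could never apply the new batch size from the gradual-training schedule. The fix rebuilds both loaders through setup_loader, handing the existing datasets back in. A small sketch of the failure mode and the rebuild, again with TensorDataset standing in for the repo's MyDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 1))
loader = DataLoader(dataset, batch_size=4)

try:
    loader.batch_size = 2              # what the removed lines attempted
except ValueError as err:
    print("refused:", err)             # batch_size is frozen after __init__

loader = DataLoader(loader.dataset, batch_size=2)  # the fix: a fresh loader
print(len(loader))                     # 4 batches of 2 samples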