enable gradual training by config.json

Eren Golge 2019-07-22 02:11:20 +02:00
parent f038b1aa3f
commit ee706b50f6
2 changed files with 28 additions and 25 deletions

train.py

@@ -20,7 +20,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, remove_experiment_folder,
                                  save_best_model, save_checkpoint, weight_decay,
                                  set_init_dict, copy_config_file, setup_model,
-                                 split_dataset)
+                                 split_dataset, gradual_training_scheduler)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -81,20 +81,6 @@ def setup_loader(ap, is_val=False, verbose=False):
     return loader
 
 
-def gradual_training_scheduler(global_step):
-    if global_step < 10000:
-        r, batch_size = 7, 32
-    elif global_step < 50000:
-        r, batch_size = 5, 32
-    elif global_step < 130000:
-        r, batch_size = 3, 32
-    elif global_step < 290000:
-        r, batch_size = 2, 16
-    else:
-        r, batch_size = 1, 16
-    return r, batch_size
-
-
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
@@ -106,8 +92,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     avg_decoder_loss = 0
     avg_stop_loss = 0
     avg_step_time = 0
+    avg_loader_time = 0
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+    end_time = time.time()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -121,6 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         stop_targets = data[6]
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())
+        loader_time = time.time() - end_time
 
         if c.use_speaker_embedding:
             speaker_ids = [speaker_mapping[speaker_name]
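The loader_time instrumentation added here follows a standard two-timestamp pattern: end_time marks when the previous iteration finished, so the gap until the next batch is available is time spent blocked on the DataLoader, separate from compute time. A minimal runnable sketch of that pattern (illustration only, not the commit's literal code; timed_epoch and step_fn are invented names):

    import time

    def timed_epoch(data_loader, step_fn):
        end_time = time.time()                    # end of the "previous" iteration
        for data in data_loader:
            loader_time = time.time() - end_time  # stall waiting on data
            start_time = time.time()
            step_fn(data)                         # forward/backward/optimizer step
            end_time = time.time()
            step_time = end_time - start_time     # pure compute time
            print(" | > StepTime:{:.2f} LoaderTime:{:.2f}".format(
                step_time, loader_time))

    # usage: three fake batches, each "step" sleeps 10 ms
    timed_epoch(range(3), lambda data: time.sleep(0.01))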
@@ -191,17 +180,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         else:
             grad_norm_st = 0
 
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
         if current_step % c.print_step == 0:
             print(
                 " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
                 "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
+                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
+                "LoaderTime:{:.2f} LR:{:.6f}".format(
                     num_iter, batch_n_iter, current_step, loss.item(),
                     postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
-                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
+                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
+                    loader_time, current_lr),
                 flush=True)
 
         # aggregate losses from processes
@@ -216,6 +204,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         avg_decoder_loss += float(decoder_loss.item())
         avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
         avg_step_time += step_time
+        avg_loader_time += loader_time
 
         # Plot Training Iter Stats
         iter_stats = {"loss_posnet": postnet_loss.item(),
@ -254,11 +243,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
{'TrainAudio': train_audio}, {'TrainAudio': train_audio},
c.audio["sample_rate"]) c.audio["sample_rate"])
step_time = end_time - start_time
epoch_time += step_time
avg_postnet_loss /= (num_iter + 1) avg_postnet_loss /= (num_iter + 1)
avg_decoder_loss /= (num_iter + 1) avg_decoder_loss /= (num_iter + 1)
avg_stop_loss /= (num_iter + 1) avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
avg_step_time /= (num_iter + 1) avg_step_time /= (num_iter + 1)
avg_loader_time /= (num_iter + 1)
# print epoch stats # print epoch stats
print( print(
@ -267,7 +260,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
"AvgStopLoss:{:.5f} EpochTime:{:.2f} " "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss, "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
avg_postnet_loss, avg_decoder_loss, avg_postnet_loss, avg_decoder_loss,
avg_stop_loss, epoch_time, avg_step_time), avg_stop_loss, epoch_time, avg_step_time,
avg_loader_time),
flush=True) flush=True)
# Plot Epoch Stats # Plot Epoch Stats
@@ -281,6 +275,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, current_step)
 
+    end_time = time.time()
     return avg_postnet_loss, current_step
@@ -541,9 +536,10 @@ def main(args): #pylint: disable=redefined-outer-name
     current_step = 0
     for epoch in range(0, c.epochs):
         # set gradual training
-        r, c.batch_size = gradual_training_scheduler(current_step)
-        c.r = r
-        model.decoder._set_r(r)
+        if c.gradual_training is not None:
+            r, c.batch_size = gradual_training_scheduler(current_step, c)
+            c.r = r
+            model.decoder._set_r(r)
         print(" > Number of outputs per iteration:", model.decoder.r)
 
         train_loss, current_step = train(model, criterion, criterion_st,
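Because the scheduler is now guarded by c.gradual_training is not None, the feature is opt-in per config: presumably a config.json that sets

    "gradual_training": null

keeps r and batch_size fixed for the entire run. The null spelling is an assumption; only the None check is visible in the diff.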
@@ -592,7 +588,7 @@ if __name__ == '__main__':
         '--output_folder',
         type=str,
         default='',
-        help='folder name for traning outputs.'
+        help='folder name for training outputs.'
     )
 
     # DISTRUBUTED

utils/generic_utils.py

@@ -305,3 +305,10 @@ def split_dataset(items):
     else:
         return items[:eval_split_size], items[eval_split_size:]
 
+
+def gradual_training_scheduler(global_step, config):
+    new_values = None
+    for values in config.gradual_training:
+        if global_step >= values[0]:
+            new_values = values
+    return new_values[1], new_values[2]
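A standalone sketch of the new contract (assumptions: entries sorted by ascending step, and SimpleNamespace standing in for the object load_config returns), fed the same schedule the deleted hard-coded function encoded:

    from types import SimpleNamespace

    def gradual_training_scheduler(global_step, config):
        # each entry is [first_step, r, batch_size]; the last entry whose
        # first_step has been reached wins
        new_values = None
        for values in config.gradual_training:
            if global_step >= values[0]:
                new_values = values
        return new_values[1], new_values[2]

    config = SimpleNamespace(gradual_training=[[0, 7, 32], [10000, 5, 32],
                                               [50000, 3, 32], [130000, 2, 16],
                                               [290000, 1, 16]])

    assert gradual_training_scheduler(0, config) == (7, 32)
    assert gradual_training_scheduler(49999, config) == (5, 32)
    assert gradual_training_scheduler(290000, config) == (1, 16)

Note that if the first entry started above step 0 (or the list were empty), new_values would still be None at the return and new_values[1] would raise a TypeError, so the first entry should cover step 0.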