Use step-wise LR scheduler + adapt train.py to pass the sequence mask directly

Eren 2018-08-10 17:48:19 +02:00
parent e0bce1d2d1
commit 96e2e3c776
1 changed file with 40 additions and 48 deletions

--- a/train.py
+++ b/train.py

@@ -16,11 +16,12 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (
     synthesis, remove_experiment_folder, create_experiment_folder,
     save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
-    check_update, get_commit_hash)
+    check_update, get_commit_hash, sequence_mask)
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
+from torch.optim.lr_scheduler import StepLR
 torch.manual_seed(1)
 torch.set_num_threads(4)
@@ -28,7 +29,7 @@ use_cuda = torch.cuda.is_available()
 def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
-          ap, epoch):
+          scheduler, ap, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -58,15 +59,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             epoch * len(data_loader) + 1
         # setup lr
-        current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
-        current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
-        for params_group in optimizer.param_groups:
-            params_group['lr'] = current_lr
-        for params_group in optimizer_st.param_groups:
-            params_group['lr'] = current_lr_st
+        scheduler.step()
         optimizer.zero_grad()
         optimizer_st.zero_grad()
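
The manual warmup-based lr_decay is replaced here by a PyTorch StepLR scheduler that is stepped once per training iteration. As a rough illustration of what that step-wise decay does (a standalone sketch, not the repo's code; the 500 / 0.5 values only stand in for whatever c.decay_step and c.lr_decay hold in the config):

import torch
from torch.optim.lr_scheduler import StepLR

# Stand-in module and optimizer; in train.py these are Tacotron and its Adam optimizer.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# step_size / gamma mirror c.decay_step / c.lr_decay (assumed example values).
scheduler = StepLR(optimizer, step_size=500, gamma=0.5)

for step in range(1, 1501):
    optimizer.zero_grad()
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    scheduler.step()  # lr stays 1e-3 for 500 steps, then drops to 5e-4, then 2.5e-4, ...
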
@@ -79,9 +72,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             linear_input = linear_input.cuda()
             stop_targets = stop_targets.cuda()
+        # compute mask for padding
+        mask = sequence_mask(text_lengths)
         # forward pass
-        mel_output, linear_output, alignments, stop_tokens =\
-            model.forward(text_input, mel_input, text_lengths)
+        mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
+            model, (text_input, mel_input, mask))
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
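
sequence_mask comes from utils.generic_utils (added to the import block above). Its body is not part of this diff, so the version below is only a guess at the usual shape of such a helper: build a (batch, max_len) boolean mask from the per-utterance text lengths, True for real tokens and False for padding.

import torch

def sequence_mask(sequence_length, max_len=None):
    # Hypothetical implementation for illustration; the real helper in
    # utils.generic_utils may differ in dtype or argument handling.
    if max_len is None:
        max_len = int(sequence_length.max())
    positions = torch.arange(max_len, device=sequence_length.device)
    # (1, max_len) < (batch, 1)  ->  (batch, max_len) boolean mask
    return positions.unsqueeze(0) < sequence_length.unsqueeze(1)

lengths = torch.tensor([5, 3, 4])
mask = sequence_mask(lengths)  # each row is True up to 5, 3 and 4 positions respectively

The forward call itself now goes through torch.nn.parallel.data_parallel, the functional counterpart of nn.DataParallel: the model is replicated over the visible GPUs for that single call. That is also why main() no longer wraps the model in nn.DataParallel and why the stopnet is reached as model.decoder.stopnet rather than model.module.decoder.stopnet further down.
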
@@ -94,7 +90,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
-        grad_norm, skip_flag = check_update(model, 0.5, 100)
+        grad_norm, skip_flag = check_update(model, 1)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!", flush=True)
@@ -103,8 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
-        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
-                                               0.5, 100)
+        grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
         if skip_flag:
             optimizer_st.zero_grad()
             print(" | | > Iteration skipped fro stopnet!!")
@@ -115,18 +110,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         epoch_time += step_time
         if current_step % c.print_step == 0:
-            print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
-                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
-                                                             batch_n_iter,
-                                                             current_step,
-                                                             loss.item(),
-                                                             linear_loss.item(),
-                                                             mel_loss.item(),
-                                                             stop_loss.item(),
-                                                             grad_norm.item(),
-                                                             grad_norm_st.item(),
-                                                             step_time), flush=True)
+            print(
+                " | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
+                "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
+                "GradNormST:{:.5f} StepTime:{:.2f}".format(
+                    num_iter, batch_n_iter, current_step, loss.item(),
+                    linear_loss.item(), mel_loss.item(), stop_loss.item(),
+                    grad_norm.item(), grad_norm_st.item(), step_time),
+                flush=True)
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
@@ -184,16 +175,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_step_time /= (num_iter + 1)
     # print epoch stats
-    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
-          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
-          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
-          "AvgStepTime:{:.2f}".format(current_step,
-                                      avg_total_loss,
-                                      avg_linear_loss,
-                                      avg_mel_loss,
-                                      avg_stop_loss,
-                                      epoch_time,
-                                      avg_step_time), flush=True)
+    print(
+        " | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+        "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
+        "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
+        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
+                                    avg_linear_loss, avg_mel_loss,
+                                    avg_stop_loss, epoch_time, avg_step_time),
+        flush=True)
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
@@ -266,8 +255,10 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
             if num_iter % c.print_step == 0:
                 print(
                     " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
-                    "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
-                                               mel_loss.item(), stop_loss.item()),
+                    "StopLoss: {:.5f} ".format(loss.item(),
+                                               linear_loss.item(),
+                                               mel_loss.item(),
+                                               stop_loss.item()),
                     flush=True)
             avg_linear_loss += linear_loss.item()
@@ -322,11 +313,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
     ap.griffin_lim_iters = 60
     for idx, test_sentence in enumerate(test_sentences):
         try:
             wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
                                                      use_cuda, c.text_cleaner)
             wav_name = 'TestSentences/{}'.format(idx)
             tb.add_audio(
                 wav_name, wav, current_step, sample_rate=c.sample_rate)
         except:
             print(" !! Error as creating Test Sentence -", idx)
             pass
@@ -405,7 +396,7 @@ def main(args):
         checkpoint = torch.load(args.restore_path)
         model.load_state_dict(checkpoint['model'])
         if use_cuda:
-            model = nn.DataParallel(model.cuda())
+            model = model.cuda()
             criterion.cuda()
             criterion_st.cuda()
         optimizer.load_state_dict(checkpoint['optimizer'])
@@ -423,10 +414,11 @@ def main(args):
         args.restore_step = 0
         print("\n > Starting a new training", flush=True)
     if use_cuda:
-        model = nn.DataParallel(model.cuda())
+        model = model.cuda()
         criterion.cuda()
         criterion_st.cuda()
+    scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params), flush=True)
@@ -439,7 +431,7 @@ def main(args):
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(model, criterion, criterion_st,
                                          train_loader, optimizer, optimizer_st,
-                                         ap, epoch)
+                                         scheduler, ap, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
                             current_step)
         print(