conditional stopnet and separate stopnet training

Eren Golge 2019-05-14 17:49:20 +02:00
parent 66453b81d3
commit 82d873b35a
1 changed file with 32 additions and 24 deletions


@@ -109,7 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.lr_decay:
             scheduler.step()
         optimizer.zero_grad()
-        optimizer_st.zero_grad()
+        if optimizer_st: optimizer_st.zero_grad();
         # dispatch data to GPU
         if use_cuda:
@@ -125,7 +125,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             text_input, text_lengths, mel_input)

         # loss computation
-        stop_loss = criterion_st(stop_tokens, stop_targets)
+        stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model == "Tacotron":
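The pattern introduced here computes the stop-token loss only when the stopnet is enabled, and substitutes a zero tensor otherwise so later arithmetic and logging still work. A minimal sketch of that pattern in isolation (the flag, criterion, and tensor shapes below are illustrative stand-ins, not the repo's objects):

import torch
import torch.nn as nn

cfg_stopnet = True                            # stands in for c.stopnet
criterion_st = nn.BCEWithLogitsLoss()         # stands in for criterion_st

stop_tokens = torch.randn(4, 50)              # raw stopnet logits (batch, frames)
stop_targets = torch.empty(4, 50).random_(2)  # 0/1 stop labels

# Compute the stop loss only when the stopnet is enabled; otherwise use a
# zero tensor so downstream code (`loss += stop_loss`, logging) still works.
stop_loss = (criterion_st(stop_tokens, stop_targets)
             if cfg_stopnet else torch.zeros(1))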
@ -139,18 +139,26 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
else: else:
postnet_loss = criterion(postnet_output, mel_input) postnet_loss = criterion(postnet_output, mel_input)
loss = decoder_loss + postnet_loss loss = decoder_loss + postnet_loss
if not c.separate_stopnet and c.stopnet:
loss += stop_loss
# backpass and check the grad norm for spec losses # backpass and check the grad norm for spec losses
if c.separate_stopnet:
loss.backward(retain_graph=True) loss.backward(retain_graph=True)
else:
loss.backward()
optimizer, current_lr = weight_decay(optimizer, c.wd) optimizer, current_lr = weight_decay(optimizer, c.wd)
grad_norm, _ = check_update(model, c.grad_clip) grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step() optimizer.step()
# backpass and check the grad norm for stop loss # backpass and check the grad norm for stop loss
if c.separate_stopnet:
stop_loss.backward() stop_loss.backward()
optimizer_st, _ = weight_decay(optimizer_st, c.wd) optimizer_st, _ = weight_decay(optimizer_st, c.wd)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0) grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step() optimizer_st.step()
else:
grad_norm_st = 0
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
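This hunk is the core of the change: when the stopnet is trained separately, the spectrogram loss is backed up first with retain_graph=True so the stop loss can walk the same graph again for its own optimizer. A minimal sketch of that two-backward pattern with toy modules (every name below is illustrative, not the repo's code; the real model may also detach the stopnet input so the stop loss updates only the stopnet weights):

import torch
import torch.nn as nn
import torch.optim as optim

decoder = nn.Linear(10, 10)   # toy stand-in for the shared decoder
stopnet = nn.Linear(10, 1)    # toy stand-in for model.decoder.stopnet

opt = optim.Adam(decoder.parameters(), lr=1e-3)     # main optimizer
opt_st = optim.Adam(stopnet.parameters(), lr=1e-3)  # stopnet-only optimizer

x = torch.randn(8, 10)
hidden = decoder(x)                       # shared forward pass
spec_loss = hidden.pow(2).mean()          # toy spectrogram loss
stop_loss = stopnet(hidden).abs().mean()  # toy stop-token loss

opt.zero_grad()
opt_st.zero_grad()

# The first backward must keep the graph alive: stop_loss.backward() will
# traverse the same shared graph a second time and would otherwise fail.
spec_loss.backward(retain_graph=True)
stop_loss.backward()

opt.step()
opt_st.step()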
@@ -175,7 +183,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if args.rank == 0:
             avg_postnet_loss += float(postnet_loss.item())
             avg_decoder_loss += float(decoder_loss.item())
-            avg_stop_loss += stop_loss.item()
+            avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
             avg_step_time += step_time

         # Plot Training Iter Stats
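Since stop_loss can now be the torch.zeros(1) placeholder rather than a criterion output, the running average has to extract a Python scalar either way. A tiny illustrative sketch of that accumulation (not the repo's code):

import torch

avg_stop_loss = 0.0
for stop_loss in (torch.tensor(0.25), 0.0, torch.zeros(1)):
    # Accept a plain float or a one-element tensor; .item() extracts the scalar.
    avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
print(avg_stop_loss)  # 0.25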
@@ -290,7 +298,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                     model.forward(text_input, text_lengths, mel_input)

                 # loss computation
-                stop_loss = criterion_st(stop_tokens, stop_targets)
+                stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                 if c.loss_masking:
                     decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
                     if c.model == "Tacotron":
@@ -395,14 +403,17 @@ def main(args):
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
-    optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    if c.stopnet and c.separate_stopnet:
+        optimizer_st = optim.Adam(
+            model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    else:
+        optimizer_st = None

     if c.loss_masking:
         criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
     else:
         criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss()
+    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
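The stopnet optimizer and criterion are now built only when they will be used, with None signalling "absent" to the guards seen earlier (if optimizer_st: ..., if criterion_st: ...). A sketch under the same assumptions, with a toy module standing in for model.decoder.stopnet and a namespace standing in for the config:

import torch.nn as nn
import torch.optim as optim
from types import SimpleNamespace

c = SimpleNamespace(stopnet=True, separate_stopnet=True, lr=1e-4)
stopnet = nn.Linear(256, 1)  # stands in for model.decoder.stopnet

# Build the stopnet-only optimizer just when both flags are on; the guards
# at the call sites (zero_grad/step) then key off optimizer_st being None.
if c.stopnet and c.separate_stopnet:
    optimizer_st = optim.Adam(stopnet.parameters(), lr=c.lr, weight_decay=0)
else:
    optimizer_st = None

criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None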
@@ -420,23 +431,19 @@ def main(args):
         model_dict = set_init_dict(model_dict, checkpoint, c)
         model.load_state_dict(model_dict)
         del model_dict
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
+        # best_loss = checkpoint['postnet_loss']
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0

     if use_cuda:
         model = model.cuda()
         criterion.cuda()
-        criterion_st.cuda()
+        if criterion_st: criterion_st.cuda();

     # DISTRUBUTED
     if num_gpus > 1:
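With criterion_st now possibly None, every GPU move has to be guarded, and the duplicated .cuda() block in the restore branch is dropped so the move happens exactly once. A minimal sketch of the surviving pattern (toy modules, not the repo's model):

import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
model = nn.Linear(10, 10)  # toy stand-in for the TTS model
criterion = nn.MSELoss()
criterion_st = None        # None when c.stopnet is disabled

if use_cuda:
    model = model.cuda()
    criterion.cuda()
    # Guard the optional stop-token criterion before moving it.
    if criterion_st:
        criterion_st.cuda()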
@@ -457,9 +464,10 @@ def main(args):
     best_loss = float('inf')
     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_st,
-                                         optimizer, optimizer_st, scheduler,
-                                         ap, epoch)
+        # train_loss, current_step = train(model, criterion, criterion_st,
+        #                                  optimizer, optimizer_st, scheduler,
+        #                                  ap, epoch)
+        train_loss, current_step = 0, 0
         val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
         print(
             " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(