mirror of https://github.com/coqui-ai/TTS.git
conditional stopnet and separate stopnet training
This commit is contained in:
parent 66453b81d3
commit 82d873b35a

train.py | 56
@@ -109,7 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.lr_decay:
             scheduler.step()
         optimizer.zero_grad()
-        optimizer_st.zero_grad()
+        if optimizer_st: optimizer_st.zero_grad();

         # dispatch data to GPU
         if use_cuda:
@@ -125,7 +125,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             text_input, text_lengths, mel_input)

         # loss computation
-        stop_loss = criterion_st(stop_tokens, stop_targets)
+        stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model == "Tacotron":
@@ -139,18 +139,26 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             else:
                 postnet_loss = criterion(postnet_output, mel_input)
         loss = decoder_loss + postnet_loss
+        if not c.separate_stopnet and c.stopnet:
+            loss += stop_loss

         # backpass and check the grad norm for spec losses
-        loss.backward(retain_graph=True)
+        if c.separate_stopnet:
+            loss.backward(retain_graph=True)
+        else:
+            loss.backward()
         optimizer, current_lr = weight_decay(optimizer, c.wd)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()

         # backpass and check the grad norm for stop loss
-        stop_loss.backward()
-        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
-        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
-        optimizer_st.step()
+        if c.separate_stopnet:
+            stop_loss.backward()
+            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            optimizer_st.step()
+        else:
+            grad_norm_st = 0

         step_time = time.time() - start_time
         epoch_time += step_time
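A note on the backward-pass change in the hunk above: when the stopnet is trained separately, the spectrogram loss is backpropagated first with retain_graph=True so that the stop loss, which shares the decoder's autograd graph, can run a second backward pass afterwards. A tiny, self-contained PyTorch illustration of why the flag is needed (unrelated to the TTS code itself):

import torch

x = torch.randn(4, requires_grad=True)
shared = x * 2                      # shared part of the autograd graph
loss_a = shared.sum()               # stands in for the spectrogram loss
loss_b = (shared ** 2).sum()        # stands in for the stop-token loss

loss_a.backward(retain_graph=True)  # keep intermediate buffers alive
loss_b.backward()                   # would raise a RuntimeError if the graph had been freed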
@@ -175,7 +183,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if args.rank == 0:
             avg_postnet_loss += float(postnet_loss.item())
             avg_decoder_loss += float(decoder_loss.item())
-            avg_stop_loss += stop_loss.item()
+            avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
             avg_step_time += step_time

             # Plot Training Iter Stats
@@ -290,7 +298,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                    model.forward(text_input, text_lengths, mel_input)

                # loss computation
-                stop_loss = criterion_st(stop_tokens, stop_targets)
+                stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                if c.loss_masking:
                    decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
                    if c.model == "Tacotron":
@@ -395,14 +403,17 @@ def main(args):
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
-    optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    if c.stopnet and c.separate_stopnet:
+        optimizer_st = optim.Adam(
+            model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    else:
+        optimizer_st = None

     if c.loss_masking:
         criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
     else:
         criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss()
+    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -420,23 +431,19 @@
             model_dict = set_init_dict(model_dict, checkpoint, c)
             model.load_state_dict(model_dict)
             del model_dict
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
-        # best_loss = checkpoint['postnet_loss']
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
+
+    if use_cuda:
+        model = model.cuda()
+        criterion.cuda()
+        if criterion_st: criterion_st.cuda();

     # DISTRUBUTED
     if num_gpus > 1:
@@ -457,9 +464,10 @@
     best_loss = float('inf')

     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_st,
-                                         optimizer, optimizer_st, scheduler,
-                                         ap, epoch)
+        # train_loss, current_step = train(model, criterion, criterion_st,
+        #                                  optimizer, optimizer_st, scheduler,
+        #                                  ap, epoch)
+        train_loss, current_step = 0, 0
         val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
         print(
             " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
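For reference, the new behaviour is driven by two boolean config flags read from `c`. A hypothetical stand-in for that config object, runnable on its own (only the two flag names come from this diff; everything else about the real config is assumed):

from types import SimpleNamespace

# Assumed excerpt of the training config as consumed by train.py.
c = SimpleNamespace(
    stopnet=True,           # compute the stop-token loss at all
    separate_stopnet=True,  # train the stopnet with its own Adam optimizer and backward pass;
                            # if False, the stop loss is simply added to the main loss
)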