mirror of https://github.com/coqui-ai/TTS.git
Use step wise LR scheduler + adapt train.py for passing squence mask directly
This commit is contained in:
parent
e0bce1d2d1
commit
96e2e3c776
88
train.py
88
train.py
|
@ -16,11 +16,12 @@ from tensorboardX import SummaryWriter
|
||||||
from utils.generic_utils import (
|
from utils.generic_utils import (
|
||||||
synthesis, remove_experiment_folder, create_experiment_folder,
|
synthesis, remove_experiment_folder, create_experiment_folder,
|
||||||
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
|
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
|
||||||
check_update, get_commit_hash)
|
check_update, get_commit_hash, sequence_mask)
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
from utils.audio import AudioProcessor
|
from utils.audio import AudioProcessor
|
||||||
|
from torch.optim.lr_scheduler import StepLR
|
||||||
|
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
torch.set_num_threads(4)
|
torch.set_num_threads(4)
|
||||||
|
@ -28,7 +29,7 @@ use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
ap, epoch):
|
scheduler, ap, epoch):
|
||||||
model = model.train()
|
model = model.train()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
avg_linear_loss = 0
|
avg_linear_loss = 0
|
||||||
|
@ -58,15 +59,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
epoch * len(data_loader) + 1
|
epoch * len(data_loader) + 1
|
||||||
|
|
||||||
# setup lr
|
# setup lr
|
||||||
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
scheduler.step()
|
||||||
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
|
|
||||||
|
|
||||||
for params_group in optimizer.param_groups:
|
|
||||||
params_group['lr'] = current_lr
|
|
||||||
|
|
||||||
for params_group in optimizer_st.param_groups:
|
|
||||||
params_group['lr'] = current_lr_st
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
optimizer_st.zero_grad()
|
optimizer_st.zero_grad()
|
||||||
|
|
||||||
|
@ -79,9 +72,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
linear_input = linear_input.cuda()
|
linear_input = linear_input.cuda()
|
||||||
stop_targets = stop_targets.cuda()
|
stop_targets = stop_targets.cuda()
|
||||||
|
|
||||||
|
# compute mask for padding
|
||||||
|
mask = sequence_mask(text_lengths)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments, stop_tokens =\
|
mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
|
||||||
model.forward(text_input, mel_input, text_lengths)
|
model, (text_input, mel_input, mask))
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
|
@ -94,7 +90,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# backpass and check the grad norm for spec losses
|
# backpass and check the grad norm for spec losses
|
||||||
loss.backward(retain_graph=True)
|
loss.backward(retain_graph=True)
|
||||||
grad_norm, skip_flag = check_update(model, 0.5, 100)
|
grad_norm, skip_flag = check_update(model, 1)
|
||||||
if skip_flag:
|
if skip_flag:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
print(" | > Iteration skipped!!", flush=True)
|
print(" | > Iteration skipped!!", flush=True)
|
||||||
|
@ -103,8 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# backpass and check the grad norm for stop loss
|
# backpass and check the grad norm for stop loss
|
||||||
stop_loss.backward()
|
stop_loss.backward()
|
||||||
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
|
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
|
||||||
0.5, 100)
|
|
||||||
if skip_flag:
|
if skip_flag:
|
||||||
optimizer_st.zero_grad()
|
optimizer_st.zero_grad()
|
||||||
print(" | | > Iteration skipped fro stopnet!!")
|
print(" | | > Iteration skipped fro stopnet!!")
|
||||||
|
@ -115,18 +110,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
if current_step % c.print_step == 0:
|
if current_step % c.print_step == 0:
|
||||||
print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
print(
|
||||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||||
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
|
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||||
batch_n_iter,
|
"GradNormST:{:.5f} StepTime:{:.2f}".format(
|
||||||
current_step,
|
num_iter, batch_n_iter, current_step, loss.item(),
|
||||||
loss.item(),
|
linear_loss.item(), mel_loss.item(), stop_loss.item(),
|
||||||
linear_loss.item(),
|
grad_norm.item(), grad_norm_st.item(), step_time),
|
||||||
mel_loss.item(),
|
flush=True)
|
||||||
stop_loss.item(),
|
|
||||||
grad_norm.item(),
|
|
||||||
grad_norm_st.item(),
|
|
||||||
step_time), flush=True)
|
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += linear_loss.item()
|
||||||
avg_mel_loss += mel_loss.item()
|
avg_mel_loss += mel_loss.item()
|
||||||
|
@ -184,16 +175,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
avg_step_time /= (num_iter + 1)
|
avg_step_time /= (num_iter + 1)
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
print(
|
||||||
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
||||||
"AvgStepTime:{:.2f}".format(current_step,
|
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||||
avg_total_loss,
|
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
||||||
avg_linear_loss,
|
avg_linear_loss, avg_mel_loss,
|
||||||
avg_mel_loss,
|
avg_stop_loss, epoch_time, avg_step_time),
|
||||||
avg_stop_loss,
|
flush=True)
|
||||||
epoch_time,
|
|
||||||
avg_step_time), flush=True)
|
|
||||||
|
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
||||||
|
@ -266,8 +255,10 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
if num_iter % c.print_step == 0:
|
if num_iter % c.print_step == 0:
|
||||||
print(
|
print(
|
||||||
" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
|
" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
|
||||||
"StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
|
"StopLoss: {:.5f} ".format(loss.item(),
|
||||||
mel_loss.item(), stop_loss.item()),
|
linear_loss.item(),
|
||||||
|
mel_loss.item(),
|
||||||
|
stop_loss.item()),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += linear_loss.item()
|
||||||
|
@ -322,11 +313,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
ap.griffin_lim_iters = 60
|
ap.griffin_lim_iters = 60
|
||||||
for idx, test_sentence in enumerate(test_sentences):
|
for idx, test_sentence in enumerate(test_sentences):
|
||||||
try:
|
try:
|
||||||
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
|
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
|
||||||
use_cuda, c.text_cleaner)
|
use_cuda, c.text_cleaner)
|
||||||
wav_name = 'TestSentences/{}'.format(idx)
|
wav_name = 'TestSentences/{}'.format(idx)
|
||||||
tb.add_audio(
|
tb.add_audio(
|
||||||
wav_name, wav, current_step, sample_rate=c.sample_rate)
|
wav_name, wav, current_step, sample_rate=c.sample_rate)
|
||||||
except:
|
except:
|
||||||
print(" !! Error as creating Test Sentence -", idx)
|
print(" !! Error as creating Test Sentence -", idx)
|
||||||
pass
|
pass
|
||||||
|
@ -405,7 +396,7 @@ def main(args):
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
model = nn.DataParallel(model.cuda())
|
model = model.cuda()
|
||||||
criterion.cuda()
|
criterion.cuda()
|
||||||
criterion_st.cuda()
|
criterion_st.cuda()
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
@ -423,10 +414,11 @@ def main(args):
|
||||||
args.restore_step = 0
|
args.restore_step = 0
|
||||||
print("\n > Starting a new training", flush=True)
|
print("\n > Starting a new training", flush=True)
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
model = nn.DataParallel(model.cuda())
|
model = model.cuda()
|
||||||
criterion.cuda()
|
criterion.cuda()
|
||||||
criterion_st.cuda()
|
criterion_st.cuda()
|
||||||
|
|
||||||
|
scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
|
||||||
num_params = count_parameters(model)
|
num_params = count_parameters(model)
|
||||||
print(" | > Model has {} parameters".format(num_params), flush=True)
|
print(" | > Model has {} parameters".format(num_params), flush=True)
|
||||||
|
|
||||||
|
@ -439,7 +431,7 @@ def main(args):
|
||||||
for epoch in range(0, c.epochs):
|
for epoch in range(0, c.epochs):
|
||||||
train_loss, current_step = train(model, criterion, criterion_st,
|
train_loss, current_step = train(model, criterion, criterion_st,
|
||||||
train_loader, optimizer, optimizer_st,
|
train_loader, optimizer, optimizer_st,
|
||||||
ap, epoch)
|
scheduler, ap, epoch)
|
||||||
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
|
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
|
||||||
current_step)
|
current_step)
|
||||||
print(
|
print(
|
||||||
|
|
Loading…
Reference in New Issue