mirror of https://github.com/coqui-ai/TTS.git
Separate backward pass for stop-token prediction
commit dc3f909ddc
parent 5a5b9263e2
train.py (29 changed lines)
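
The change gives the stop-token head (stopnet) its own Adam optimizer and its own backward pass: the spectrogram losses are backpropagated first with retain_graph=True so the shared activations survive, then stop_loss is backpropagated and applied to the stopnet parameters alone. A minimal standalone sketch of the pattern (the toy modules and shapes here are illustrative, not the repo's):

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # Toy stand-ins for the shared network and its stop-token head.
    backbone = nn.Linear(8, 8)
    stopnet = nn.Linear(8, 1)

    optimizer = optim.Adam(backbone.parameters(), lr=1e-3)
    optimizer_st = optim.Adam(stopnet.parameters(), lr=1e-3)

    x = torch.randn(4, 8)
    hidden = backbone(x)
    loss = hidden.pow(2).mean()  # stands in for the mel/linear losses
    stop_loss = nn.functional.binary_cross_entropy(
        torch.sigmoid(stopnet(hidden)), torch.ones(4, 1))

    optimizer.zero_grad()
    optimizer_st.zero_grad()
    loss.backward(retain_graph=True)  # keep the graph for the second backward
    optimizer.step()
    stop_loss.backward()              # reuses the retained activations
    optimizer_st.step()

Note that stop_loss.backward() also deposits gradients on the shared parameters, but optimizer_st steps only the stopnet ones; the stray gradients are discarded by the next iteration's zero_grad() calls.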
@@ -59,7 +59,7 @@ LOG_DIR = OUT_PATH
 tb = SummaryWriter(LOG_DIR)
 
 
-def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
+def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -88,10 +88,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
 
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
 
         for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
 
+        for params_group in optimizer_st.param_groups:
+            params_group['lr'] = current_lr
+
         optimizer.zero_grad()
+        optimizer_st.zero_grad()
+
         # dispatch data to GPU
         if use_cuda:
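
Both optimizers are deliberately kept on the same warmup schedule. lr_decay itself is not part of this diff; given the call site lr_decay(c.lr, current_step, c.warmup_steps), a Noam-style warmup-then-inverse-sqrt shape is a plausible reading (an assumption, not the repo's confirmed definition):

    def lr_decay(init_lr, global_step, warmup_steps):
        # Assumed Noam-style schedule: ramp up for `warmup_steps` steps,
        # then decay in proportion to 1/sqrt(step).
        step = global_step + 1
        return init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5,
                                                 step**-0.5)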
@@ -112,16 +117,25 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
                           + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                             linear_input[:, :, :n_priority_freq],
                                             mel_lengths)
-        loss = mel_loss + linear_loss + stop_loss
+        loss = mel_loss + linear_loss
 
-        # backpass and check the grad norm
-        loss.backward()
+        # backpass and check the grad norm for spec losses
+        loss.backward(retain_graph=True)
         grad_norm, skip_flag = check_update(model, 0.5, 100)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!")
             continue
         optimizer.step()
 
+        # backpass and check the grad norm for stop loss
+        stop_loss.backward()
+        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
+        if skip_flag:
+            optimizer_st.zero_grad()
+            print(" | > Iteration skipped for stopnet!!")
+            continue
+        optimizer_st.step()
+
         step_time = time.time() - start_time
         epoch_time += step_time
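
Two details carry this hunk. First, loss.backward(retain_graph=True) is what makes the second backward legal: by default PyTorch frees the intermediate activations during the first backward, and stop_loss.backward() would then raise a runtime error, because the stopnet's input depends on the same graph. Second, model.module.decoder.stopnet implies the model is wrapped in nn.DataParallel, whose .module attribute exposes the underlying network. check_update is defined elsewhere in the repo; a sketch of what it presumably does, judging from the call sites (clip the gradient norm, flag pathological steps):

    import torch

    def check_update(model, grad_clip, grad_top):
        # Assumed behaviour: clip the total gradient norm to `grad_clip` and
        # ask the caller to skip the step if the norm is inf or > `grad_top`.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        skip_flag = bool(torch.isinf(grad_norm)) or grad_norm > grad_top
        return grad_norm, skip_flag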
@@ -131,7 +145,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
                               ('linear_loss', linear_loss.item()),
                               ('mel_loss', mel_loss.item()),
                               ('stop_loss', stop_loss.item()),
-                              ('grad_norm', grad_norm.item())])
+                              ('grad_norm', grad_norm.item()),
+                              ('grad_norm_st', grad_norm_st.item())])
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
         avg_stop_loss += stop_loss.item()
@@ -144,6 +159,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                           current_step)
             tb.add_scalar('Params/GradNorm', grad_norm, current_step)
+            tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
             tb.add_scalar('Time/StepTime', step_time, current_step)
 
         if current_step % c.save_step == 0:
@@ -345,6 +361,7 @@ def main(args):
                     c.r)
 
    optimizer = optim.Adam(model.parameters(), lr=c.lr)
+   optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
 
    criterion = L1LossMasked()
    criterion_st = nn.BCELoss()
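
optimizer_st is built over model.decoder.stopnet.parameters() only, so the stop-token pass can never move the rest of the network. criterion_st = nn.BCELoss() expects probabilities in [0, 1] against binary targets; building those targets is not shown in this diff, but a typical construction (a hypothetical helper, not the repo's) is 0 while decoding should continue and 1 from the final frame on:

    import torch

    def make_stop_targets(mel_lengths, max_len):
        # Hypothetical helper: one row per utterance, 0.0 until the last
        # frame and 1.0 from it onward. Shape: [batch, max_len].
        steps = torch.arange(max_len).unsqueeze(0)
        return (steps >= (mel_lengths.unsqueeze(1) - 1)).float()

    print(make_stop_targets(torch.tensor([3, 5]), max_len=5))
    # tensor([[0., 0., 1., 1., 1.],
    #         [0., 0., 0., 0., 1.]])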
@@ -378,7 +395,7 @@ def main(args):
 
    for epoch in range(0, c.epochs):
        train_loss, current_step = train(
-           model, criterion, criterion_st, train_loader, optimizer, epoch)
+           model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
        val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
        best_loss = save_best_model(model, optimizer, val_loss,
                                    best_loss, OUT_PATH,