mirror of https://github.com/coqui-ai/TTS.git
compute alignment diagonality score and encapsulate stats averaging with a class in training
This commit is contained in:
parent
d1828c9573
commit
609d8efa69
202
train.py
202
train.py
|
@ -20,7 +20,7 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
|
||||||
load_config, remove_experiment_folder,
|
load_config, remove_experiment_folder,
|
||||||
save_best_model, save_checkpoint, weight_decay,
|
save_best_model, save_checkpoint, weight_decay,
|
||||||
set_init_dict, copy_config_file, setup_model,
|
set_init_dict, copy_config_file, setup_model,
|
||||||
split_dataset, gradual_training_scheduler)
|
split_dataset, gradual_training_scheduler, KeepAverage)
|
||||||
from TTS.utils.logger import Logger
|
from TTS.utils.logger import Logger
|
||||||
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
||||||
get_speakers
|
get_speakers
|
||||||
|
@ -29,6 +29,7 @@ from TTS.utils.text.symbols import phonemes, symbols
|
||||||
from TTS.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.utils.visual import plot_alignment, plot_spectrogram
|
||||||
from TTS.datasets.preprocess import get_preprocessor_by_name
|
from TTS.datasets.preprocess import get_preprocessor_by_name
|
||||||
from TTS.utils.radam import RAdam
|
from TTS.utils.radam import RAdam
|
||||||
|
from TTS.utils.measures import alignment_diagonal_score
|
||||||
|
|
||||||
|
|
||||||
torch.backends.cudnn.enabled = True
|
torch.backends.cudnn.enabled = True
|
||||||
|
@ -45,12 +46,14 @@ def setup_loader(ap, is_val=False, verbose=False):
|
||||||
global meta_data_eval
|
global meta_data_eval
|
||||||
if "meta_data_train" not in globals():
|
if "meta_data_train" not in globals():
|
||||||
if c.meta_file_train is not None:
|
if c.meta_file_train is not None:
|
||||||
meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train)
|
meta_data_train = get_preprocessor_by_name(
|
||||||
|
c.dataset)(c.data_path, c.meta_file_train)
|
||||||
else:
|
else:
|
||||||
meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
|
meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
|
||||||
if "meta_data_eval" not in globals() and c.run_eval:
|
if "meta_data_eval" not in globals() and c.run_eval:
|
||||||
if c.meta_file_val is not None:
|
if c.meta_file_val is not None:
|
||||||
meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_val)
|
meta_data_eval = get_preprocessor_by_name(
|
||||||
|
c.dataset)(c.data_path, c.meta_file_val)
|
||||||
else:
|
else:
|
||||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||||
if is_val and not c.run_eval:
|
if is_val and not c.run_eval:
|
||||||
|
@ -90,14 +93,20 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||||
model.train()
|
model.train()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
avg_postnet_loss = 0
|
train_values = {
|
||||||
avg_decoder_loss = 0
|
'avg_postnet_loss': 0,
|
||||||
avg_stop_loss = 0
|
'avg_decoder_loss': 0,
|
||||||
avg_step_time = 0
|
'avg_stop_loss': 0,
|
||||||
avg_loader_time = 0
|
'avg_align_score': 0,
|
||||||
|
'avg_step_time': 0,
|
||||||
|
'avg_loader_time': 0,
|
||||||
|
'avg_alignment_score': 0}
|
||||||
|
keep_avg = KeepAverage()
|
||||||
|
keep_avg.add_values(train_values)
|
||||||
print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
|
batch_n_iter = int(len(data_loader.dataset) /
|
||||||
|
(c.batch_size * num_gpus))
|
||||||
else:
|
else:
|
||||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
@ -108,7 +117,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
text_input = data[0]
|
text_input = data[0]
|
||||||
text_lengths = data[1]
|
text_lengths = data[1]
|
||||||
speaker_names = data[2]
|
speaker_names = data[2]
|
||||||
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
|
linear_input = data[3] if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else None
|
||||||
mel_input = data[4]
|
mel_input = data[4]
|
||||||
mel_lengths = data[5]
|
mel_lengths = data[5]
|
||||||
stop_targets = data[6]
|
stop_targets = data[6]
|
||||||
|
@ -126,7 +136,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
# set stop targets view, we predict a single stop token per r frames prediction
|
# set stop targets view, we predict a single stop token per r frames prediction
|
||||||
stop_targets = stop_targets.view(text_input.shape[0],
|
stop_targets = stop_targets.view(text_input.shape[0],
|
||||||
stop_targets.size(1) // c.r, -1)
|
stop_targets.size(1) // c.r, -1)
|
||||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(
|
||||||
|
2).float().squeeze(2)
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
@ -143,7 +154,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
text_lengths = text_lengths.cuda(non_blocking=True)
|
text_lengths = text_lengths.cuda(non_blocking=True)
|
||||||
mel_input = mel_input.cuda(non_blocking=True)
|
mel_input = mel_input.cuda(non_blocking=True)
|
||||||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||||
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
|
linear_input = linear_input.cuda(non_blocking=True) if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else None
|
||||||
stop_targets = stop_targets.cuda(non_blocking=True)
|
stop_targets = stop_targets.cuda(non_blocking=True)
|
||||||
if speaker_ids is not None:
|
if speaker_ids is not None:
|
||||||
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
||||||
|
@ -153,13 +165,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
|
stop_loss = criterion_st(
|
||||||
|
stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
|
||||||
if c.loss_masking:
|
if c.loss_masking:
|
||||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||||
if c.model in ["Tacotron", "TacotronGST"]:
|
if c.model in ["Tacotron", "TacotronGST"]:
|
||||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
postnet_loss = criterion(
|
||||||
|
postnet_output, linear_input, mel_lengths)
|
||||||
else:
|
else:
|
||||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
postnet_loss = criterion(
|
||||||
|
postnet_output, mel_input, mel_lengths)
|
||||||
else:
|
else:
|
||||||
decoder_loss = criterion(decoder_output, mel_input)
|
decoder_loss = criterion(decoder_output, mel_input)
|
||||||
if c.model in ["Tacotron", "TacotronGST"]:
|
if c.model in ["Tacotron", "TacotronGST"]:
|
||||||
|
@ -175,6 +190,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
grad_norm, _ = check_update(model, c.grad_clip)
|
grad_norm, _ = check_update(model, c.grad_clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
# compute alignment score
|
||||||
|
align_score = alignment_diagonal_score(alignments)
|
||||||
|
keep_avg.update_value('avg_align_score', align_score)
|
||||||
|
|
||||||
# backpass and check the grad norm for stop loss
|
# backpass and check the grad norm for stop loss
|
||||||
if c.separate_stopnet:
|
if c.separate_stopnet:
|
||||||
stop_loss.backward()
|
stop_loss.backward()
|
||||||
|
@ -189,12 +208,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
|
|
||||||
if global_step % c.print_step == 0:
|
if global_step % c.print_step == 0:
|
||||||
print(
|
print(
|
||||||
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
|
" | > Step:{}/{} GlobalStep:{} PostnetLoss:{:.5f} "
|
||||||
"DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
"DecoderLoss:{:.5f} StopLoss:{:.5f} AlignScore:{:.4f} GradNorm:{:.5f} "
|
||||||
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
|
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
|
||||||
"LoaderTime:{:.2f} LR:{:.6f}".format(
|
"LoaderTime:{:.2f} LR:{:.6f}".format(
|
||||||
num_iter, batch_n_iter, global_step, loss.item(),
|
num_iter, batch_n_iter, global_step,
|
||||||
postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
|
postnet_loss.item(), decoder_loss.item(), stop_loss.item(), align_score.item(),
|
||||||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
|
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
|
||||||
loader_time, current_lr),
|
loader_time, current_lr),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
@ -204,14 +223,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
|
postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
|
||||||
decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
|
decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
|
||||||
loss = reduce_tensor(loss.data, num_gpus)
|
loss = reduce_tensor(loss.data, num_gpus)
|
||||||
stop_loss = reduce_tensor(stop_loss.data, num_gpus) if c.stopnet else stop_loss
|
stop_loss = reduce_tensor(
|
||||||
|
stop_loss.data, num_gpus) if c.stopnet else stop_loss
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
avg_postnet_loss += float(postnet_loss.item())
|
update_train_values = {'avg_postnet_loss': float(postnet_loss.item()),
|
||||||
avg_decoder_loss += float(decoder_loss.item())
|
'avg_decoder_loss': float(decoder_loss.item()),
|
||||||
avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
|
'avg_stop_loss': stop_loss if isinstance(stop_loss, float) else float(stop_loss.item()),
|
||||||
avg_step_time += step_time
|
'avg_step_time': step_time,
|
||||||
avg_loader_time += loader_time
|
'avg_loader_time': loader_time}
|
||||||
|
keep_avg.update_values(update_train_values)
|
||||||
|
|
||||||
# Plot Training Iter Stats
|
# Plot Training Iter Stats
|
||||||
# reduce TB load
|
# reduce TB load
|
||||||
|
@ -233,7 +254,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
const_spec = postnet_output[0].data.cpu().numpy()
|
const_spec = postnet_output[0].data.cpu().numpy()
|
||||||
gt_spec = linear_input[0].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy()
|
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy()
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
|
||||||
figures = {
|
figures = {
|
||||||
|
@ -253,35 +275,28 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
c.audio["sample_rate"])
|
c.audio["sample_rate"])
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
avg_postnet_loss /= (num_iter + 1)
|
|
||||||
avg_decoder_loss /= (num_iter + 1)
|
|
||||||
avg_stop_loss /= (num_iter + 1)
|
|
||||||
avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
|
|
||||||
avg_step_time /= (num_iter + 1)
|
|
||||||
avg_loader_time /= (num_iter + 1)
|
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
print(
|
print(
|
||||||
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||||
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
||||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||||
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, avg_total_loss,
|
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, keep_avg['avg_postnet_loss'], keep_avg['avg_decoder_loss'],
|
||||||
avg_postnet_loss, avg_decoder_loss,
|
keep_avg['avg_stop_loss'], keep_avg['avg_align_score'],
|
||||||
avg_stop_loss, epoch_time, avg_step_time,
|
epoch_time, keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
|
||||||
avg_loader_time),
|
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
# Plot Epoch Stats
|
# Plot Epoch Stats
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
epoch_stats = {"loss_postnet": avg_postnet_loss,
|
epoch_stats = {"loss_postnet": keep_avg['avg_postnet_loss'],
|
||||||
"loss_decoder": avg_decoder_loss,
|
"loss_decoder": keep_avg['avg_decoder_loss'],
|
||||||
"stop_loss": avg_stop_loss,
|
"stop_loss": keep_avg['avg_stop_loss'],
|
||||||
|
"alignment_score": keep_avg['avg_align_score'],
|
||||||
"epoch_time": epoch_time}
|
"epoch_time": epoch_time}
|
||||||
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
||||||
if c.tb_model_param_stats:
|
if c.tb_model_param_stats:
|
||||||
tb_logger.tb_model_weights(model, global_step)
|
tb_logger.tb_model_weights(model, global_step)
|
||||||
return avg_postnet_loss, global_step
|
return keep_avg['avg_postnet_loss'], global_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
|
@ -290,9 +305,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||||
model.eval()
|
model.eval()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
avg_postnet_loss = 0
|
eval_values_dict = {'avg_postnet_loss' : 0,
|
||||||
avg_decoder_loss = 0
|
'avg_decoder_loss' : 0,
|
||||||
avg_stop_loss = 0
|
'avg_stop_loss' : 0,
|
||||||
|
'avg_align_score': 0}
|
||||||
|
keep_avg = KeepAverage()
|
||||||
|
keep_avg.add_values(eval_values_dict)
|
||||||
print("\n > Validation")
|
print("\n > Validation")
|
||||||
if c.test_sentences_file is None:
|
if c.test_sentences_file is None:
|
||||||
test_sentences = [
|
test_sentences = [
|
||||||
|
@ -313,7 +331,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
text_input = data[0]
|
text_input = data[0]
|
||||||
text_lengths = data[1]
|
text_lengths = data[1]
|
||||||
speaker_names = data[2]
|
speaker_names = data[2]
|
||||||
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
|
linear_input = data[3] if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else None
|
||||||
mel_input = data[4]
|
mel_input = data[4]
|
||||||
mel_lengths = data[5]
|
mel_lengths = data[5]
|
||||||
stop_targets = data[6]
|
stop_targets = data[6]
|
||||||
|
@ -329,14 +348,16 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
stop_targets = stop_targets.view(text_input.shape[0],
|
stop_targets = stop_targets.view(text_input.shape[0],
|
||||||
stop_targets.size(1) // c.r,
|
stop_targets.size(1) // c.r,
|
||||||
-1)
|
-1)
|
||||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(
|
||||||
|
2).float().squeeze(2)
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
text_input = text_input.cuda()
|
text_input = text_input.cuda()
|
||||||
mel_input = mel_input.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths = mel_lengths.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None
|
linear_input = linear_input.cuda() if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else None
|
||||||
stop_targets = stop_targets.cuda()
|
stop_targets = stop_targets.cuda()
|
||||||
if speaker_ids is not None:
|
if speaker_ids is not None:
|
||||||
speaker_ids = speaker_ids.cuda()
|
speaker_ids = speaker_ids.cuda()
|
||||||
|
@ -347,13 +368,17 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
speaker_ids=speaker_ids)
|
speaker_ids=speaker_ids)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
|
stop_loss = criterion_st(
|
||||||
|
stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
|
||||||
if c.loss_masking:
|
if c.loss_masking:
|
||||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
decoder_loss = criterion(
|
||||||
|
decoder_output, mel_input, mel_lengths)
|
||||||
if c.model in ["Tacotron", "TacotronGST"]:
|
if c.model in ["Tacotron", "TacotronGST"]:
|
||||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
postnet_loss = criterion(
|
||||||
|
postnet_output, linear_input, mel_lengths)
|
||||||
else:
|
else:
|
||||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
postnet_loss = criterion(
|
||||||
|
postnet_output, mel_input, mel_lengths)
|
||||||
else:
|
else:
|
||||||
decoder_loss = criterion(decoder_output, mel_input)
|
decoder_loss = criterion(decoder_output, mel_input)
|
||||||
if c.model in ["Tacotron", "TacotronGST"]:
|
if c.model in ["Tacotron", "TacotronGST"]:
|
||||||
|
@ -365,14 +390,9 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
if num_iter % c.print_step == 0:
|
# compute alignment score
|
||||||
print(
|
align_score = alignment_diagonal_score(alignments)
|
||||||
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} DecoderLoss:{:.5f} "
|
keep_avg.update_value('avg_align_score', align_score)
|
||||||
"StopLoss: {:.5f} ".format(loss.item(),
|
|
||||||
postnet_loss.item(),
|
|
||||||
decoder_loss.item(),
|
|
||||||
stop_loss.item()),
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# aggregate losses from processes
|
# aggregate losses from processes
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
|
@ -381,15 +401,26 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
if c.stopnet:
|
if c.stopnet:
|
||||||
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
||||||
|
|
||||||
avg_postnet_loss += float(postnet_loss.item())
|
keep_avg.update_values({'avg_postnet_loss' : float(postnet_loss.item()),
|
||||||
avg_decoder_loss += float(decoder_loss.item())
|
'avg_decoder_loss' : float(decoder_loss.item()),
|
||||||
avg_stop_loss += stop_loss.item()
|
'avg_stop_loss' : float(stop_loss.item())})
|
||||||
|
|
||||||
|
if num_iter % c.print_step == 0:
|
||||||
|
print(
|
||||||
|
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
|
||||||
|
"StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(loss.item(),
|
||||||
|
postnet_loss.item(), keep_avg['avg_postnet_loss'],
|
||||||
|
decoder_loss.item(), keep_avg['avg_decoder_loss'],
|
||||||
|
stop_loss.item(), keep_avg['avg_stop_loss'],
|
||||||
|
align_score.item(), keep_avg['avg_align_score']),
|
||||||
|
flush=True)
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
idx = np.random.randint(mel_input.shape[0])
|
idx = np.random.randint(mel_input.shape[0])
|
||||||
const_spec = postnet_output[idx].data.cpu().numpy()
|
const_spec = postnet_output[idx].data.cpu().numpy()
|
||||||
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[idx].data.cpu().numpy()
|
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else mel_input[idx].data.cpu().numpy()
|
||||||
align_img = alignments[idx].data.cpu().numpy()
|
align_img = alignments[idx].data.cpu().numpy()
|
||||||
|
|
||||||
eval_figures = {
|
eval_figures = {
|
||||||
|
@ -404,17 +435,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
eval_audio = ap.inv_spectrogram(const_spec.T)
|
eval_audio = ap.inv_spectrogram(const_spec.T)
|
||||||
else:
|
else:
|
||||||
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
|
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||||
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
|
tb_logger.tb_eval_audios(
|
||||||
|
global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
|
||||||
# compute average losses
|
|
||||||
avg_postnet_loss /= (num_iter + 1)
|
|
||||||
avg_decoder_loss /= (num_iter + 1)
|
|
||||||
avg_stop_loss /= (num_iter + 1)
|
|
||||||
|
|
||||||
# Plot Validation Stats
|
# Plot Validation Stats
|
||||||
epoch_stats = {"loss_postnet": avg_postnet_loss,
|
epoch_stats = {"loss_postnet": keep_avg['avg_postnet_loss'],
|
||||||
"loss_decoder": avg_decoder_loss,
|
"loss_decoder": keep_avg['avg_decoder_loss'],
|
||||||
"stop_loss": avg_stop_loss}
|
"stop_loss": keep_avg['avg_stop_loss']}
|
||||||
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
||||||
|
|
||||||
if args.rank == 0 and epoch > c.test_delay_epochs:
|
if args.rank == 0 and epoch > c.test_delay_epochs:
|
||||||
|
@ -436,18 +463,21 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
"TestSentence_{}.wav".format(idx))
|
"TestSentence_{}.wav".format(idx))
|
||||||
ap.save_wav(wav, file_path)
|
ap.save_wav(wav, file_path)
|
||||||
test_audios['{}-audio'.format(idx)] = wav
|
test_audios['{}-audio'.format(idx)] = wav
|
||||||
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(postnet_output, ap)
|
test_figures['{}-prediction'.format(idx)
|
||||||
test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment)
|
] = plot_spectrogram(postnet_output, ap)
|
||||||
|
test_figures['{}-alignment'.format(idx)
|
||||||
|
] = plot_alignment(alignment)
|
||||||
except:
|
except:
|
||||||
print(" !! Error creating Test Sentence -", idx)
|
print(" !! Error creating Test Sentence -", idx)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
tb_logger.tb_test_audios(global_step, test_audios, c.audio['sample_rate'])
|
tb_logger.tb_test_audios(
|
||||||
|
global_step, test_audios, c.audio['sample_rate'])
|
||||||
tb_logger.tb_test_figures(global_step, test_figures)
|
tb_logger.tb_test_figures(global_step, test_figures)
|
||||||
return avg_postnet_loss
|
return keep_avg['avg_postnet_loss']
|
||||||
|
|
||||||
|
|
||||||
#FIXME: move args definition/parsing inside of main?
|
# FIXME: move args definition/parsing inside of main?
|
||||||
def main(args): #pylint: disable=redefined-outer-name
|
def main(args): # pylint: disable=redefined-outer-name
|
||||||
# Audio processor
|
# Audio processor
|
||||||
ap = AudioProcessor(**c.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
|
@ -488,9 +518,11 @@ def main(args): #pylint: disable=redefined-outer-name
|
||||||
optimizer_st = None
|
optimizer_st = None
|
||||||
|
|
||||||
if c.loss_masking:
|
if c.loss_masking:
|
||||||
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"] else MSELossMasked()
|
criterion = L1LossMasked() if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else MSELossMasked()
|
||||||
else:
|
else:
|
||||||
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"] else nn.MSELoss()
|
criterion = nn.L1Loss() if c.model in [
|
||||||
|
"Tacotron", "TacotronGST"] else nn.MSELoss()
|
||||||
criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
|
criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
|
@ -552,7 +584,8 @@ def main(args): #pylint: disable=redefined-outer-name
|
||||||
train_loss, global_step = train(model, criterion, criterion_st,
|
train_loss, global_step = train(model, criterion, criterion_st,
|
||||||
optimizer, optimizer_st, scheduler,
|
optimizer, optimizer_st, scheduler,
|
||||||
ap, global_step, epoch)
|
ap, global_step, epoch)
|
||||||
val_loss = evaluate(model, criterion, criterion_st, ap, global_step, epoch)
|
val_loss = evaluate(model, criterion, criterion_st,
|
||||||
|
ap, global_step, epoch)
|
||||||
print(
|
print(
|
||||||
" | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
|
" | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
|
||||||
train_loss, val_loss),
|
train_loss, val_loss),
|
||||||
|
@ -635,7 +668,8 @@ if __name__ == '__main__':
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
new_fields["restore_path"] = args.restore_path
|
new_fields["restore_path"] = args.restore_path
|
||||||
new_fields["github_branch"] = get_git_branch()
|
new_fields["github_branch"] = get_git_branch()
|
||||||
copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'), new_fields)
|
copy_config_file(args.config_path, os.path.join(
|
||||||
|
OUT_PATH, 'config.json'), new_fields)
|
||||||
os.chmod(AUDIO_PATH, 0o775)
|
os.chmod(AUDIO_PATH, 0o775)
|
||||||
os.chmod(OUT_PATH, 0o775)
|
os.chmod(OUT_PATH, 0o775)
|
||||||
|
|
||||||
|
@ -650,8 +684,8 @@ if __name__ == '__main__':
|
||||||
try:
|
try:
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
os._exit(0) #pylint: disable=protected-access
|
os._exit(0) # pylint: disable=protected-access
|
||||||
except Exception: #pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
remove_experiment_folder(OUT_PATH)
|
remove_experiment_folder(OUT_PATH)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
|
@ -314,3 +314,34 @@ def gradual_training_scheduler(global_step, config):
|
||||||
if global_step >= values[0]:
|
if global_step >= values[0]:
|
||||||
new_values = values
|
new_values = values
|
||||||
return new_values[1], new_values[2]
|
return new_values[1], new_values[2]
|
||||||
|
|
||||||
|
|
||||||
|
class KeepAverage():
    """Keep running averages of named values.

    Each tracked name has a current average (``avg_values[name]``) and a
    count of updates applied so far (``iters[name]``). The current average
    is read with ``keep_avg[name]``.
    """

    def __init__(self):
        self.avg_values = {}  # name -> current average
        self.iters = {}  # name -> number of updates folded into the average

    def __getitem__(self, key):
        """Return the current average stored under ``key``.

        Raises:
            KeyError: if ``key`` was never registered.
        """
        return self.avg_values[key]

    def add_value(self, name, init_val=0, init_iter=0):
        """Register ``name`` (or reset it) with a starting average and count.

        Args:
            name: key of the tracked value.
            init_val: starting average.
            init_iter: number of updates the starting average already
                represents. With the default 0, ``init_val`` is discarded
                by the first arithmetic-mean update.
        """
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the average stored under ``name``.

        Args:
            name: key previously registered with ``add_value``/``add_values``.
            value: new observation.
            weighted_avg: if True, apply exponential smoothing
                (``0.99 * old + 0.01 * value``); otherwise keep the exact
                arithmetic mean of all observations.

        Raises:
            KeyError: if ``name`` was never registered.
        """
        previous = self.avg_values[name]
        if weighted_avg:
            self.avg_values[name] = 0.99 * previous + 0.01 * value
        else:
            seen = self.iters[name]
            # Rebuild the running sum, add the new sample, divide by new count.
            self.avg_values[name] = (previous * seen + value) / (seen + 1)
        # The counter advances in both modes.
        self.iters[name] += 1

    def add_values(self, name_dict):
        """Register every ``name -> initial value`` pair of ``name_dict``."""
        for key, value in name_dict.items():
            self.add_value(key, init_val=value)

    def update_values(self, value_dict, weighted_avg=False):
        """Apply ``update_value`` to every ``name -> value`` pair.

        Args:
            value_dict: mapping of registered names to new observations.
            weighted_avg: forwarded to ``update_value`` for every entry.
        """
        for key, value in value_dict.items():
            self.update_value(key, value, weighted_avg=weighted_avg)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def alignment_diagonal_score(alignments):
    """
    Compute how diagonal alignment predictions are. It is useful
    to measure the alignment consistency of a model.

    For every encoder step the maximum attention weight over all decoder
    steps is taken; these maxima are averaged over the encoder steps and
    then over the batch.

    Args:
        alignments (torch.Tensor): batch of alignments.
    Shape:
        alignments : batch x decoder_steps x encoder_steps
    Returns:
        torch.Tensor: zero-dimensional score, detached from the autograd
        graph (use ``.item()`` to get a Python float).
    """
    # The score is a logging metric. Detaching keeps callers that accumulate
    # it (e.g. in a running average) from holding on to the autograd graph.
    # NOTE(review): padded positions, if any, are included in the means.
    peak_per_encoder_step = alignments.detach().max(dim=1)[0]
    return peak_per_encoder_step.mean(dim=1).mean(dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue