linter fix

Eren Golge 2019-10-04 18:36:32 +02:00
parent 0849e3c42f
commit fbfa20e3b3
1 changed file with 170 additions and 141 deletions

train.py

@@ -15,12 +15,11 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
 from TTS.layers.losses import L1LossMasked, MSELossMasked
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
-                                     create_experiment_folder, get_git_branch,
-                                     load_config, remove_experiment_folder,
-                                     save_best_model, save_checkpoint, adam_weight_decay,
-                                     set_init_dict, copy_config_file, setup_model,
-                                     split_dataset, gradual_training_scheduler, KeepAverage,
-                                     set_weight_decay)
+from TTS.utils.generic_utils import (
+    NoamLR, check_update, count_parameters, create_experiment_folder,
+    get_git_branch, load_config, remove_experiment_folder, save_best_model,
+    save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
+    setup_model, gradual_training_scheduler, KeepAverage,
+    set_weight_decay)
 from TTS.utils.logger import Logger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@@ -32,7 +31,6 @@ from TTS.datasets.preprocess import load_meta_data
 from TTS.utils.radam import RAdam
 from TTS.utils.measures import alignment_diagonal_score
 
-
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = False
 torch.manual_seed(54321)
@@ -51,7 +49,8 @@ def setup_loader(ap, is_val=False, verbose=False):
             c.text_cleaner,
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
-            batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
+            batch_group_size=0 if is_val else c.batch_group_size *
+            c.batch_size,
             min_seq_len=c.min_seq_len,
             max_seq_len=c.max_seq_len,
             phoneme_cache_path=c.phoneme_cache_path,
@@ -87,13 +86,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         'avg_align_score': 0,
         'avg_step_time': 0,
         'avg_loader_time': 0,
-        'avg_alignment_score': 0}
+        'avg_alignment_score': 0
+    }
     keep_avg = KeepAverage()
     keep_avg.add_values(train_values)
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) /
-                           (c.batch_size * num_gpus))
+        batch_n_iter = int(
+            len(data_loader.dataset) / (c.batch_size * num_gpus))
     else:
         batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
@@ -104,8 +104,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         text_input = data[0]
         text_lengths = data[1]
         speaker_names = data[2]
-        linear_input = data[3] if c.model in [
-            "Tacotron", "TacotronGST"] else None
+        linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
+                                              ] else None
         mel_input = data[4]
         mel_lengths = data[5]
         stop_targets = data[6]
@@ -114,8 +114,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         loader_time = time.time() - end_time
 
         if c.use_speaker_embedding:
-            speaker_ids = [speaker_mapping[speaker_name]
-                           for speaker_name in speaker_names]
+            speaker_ids = [
+                speaker_mapping[speaker_name] for speaker_name in speaker_names
+            ]
             speaker_ids = torch.LongTensor(speaker_ids)
         else:
             speaker_ids = None
@@ -123,8 +124,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
 
         # set stop targets view, we predict a single stop token per r frames prediction
         stop_targets = stop_targets.view(text_input.shape[0],
                                          stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(
-            2).float().squeeze(2)
+        stop_targets = (stop_targets.sum(2) >
+                        0.0).unsqueeze(2).float().squeeze(2)
 
         global_step += 1
@@ -141,8 +142,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             text_lengths = text_lengths.cuda(non_blocking=True)
             mel_input = mel_input.cuda(non_blocking=True)
             mel_lengths = mel_lengths.cuda(non_blocking=True)
-            linear_input = linear_input.cuda(non_blocking=True) if c.model in [
-                "Tacotron", "TacotronGST"] else None
+            linear_input = linear_input.cuda(
+                non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
+                                                  ] else None
             stop_targets = stop_targets.cuda(non_blocking=True)
             if speaker_ids is not None:
                 speaker_ids = speaker_ids.cuda(non_blocking=True)
@@ -152,16 +154,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
 
         # loss computation
-        stop_loss = criterion_st(
-            stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+        stop_loss = criterion_st(stop_tokens,
+                                 stop_targets) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:
-                postnet_loss = criterion(
-                    postnet_output, linear_input, mel_lengths)
+                postnet_loss = criterion(postnet_output, linear_input,
+                                         mel_lengths)
             else:
-                postnet_loss = criterion(
-                    postnet_output, mel_input, mel_lengths)
+                postnet_loss = criterion(postnet_output, mel_input,
+                                         mel_lengths)
         else:
             decoder_loss = criterion(decoder_output, mel_input)
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -199,10 +201,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                 "DecoderLoss:{:.5f} StopLoss:{:.5f} AlignScore:{:.4f} GradNorm:{:.5f} "
                 "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
                 "LoaderTime:{:.2f} LR:{:.6f}".format(
-                    num_iter, batch_n_iter, global_step,
-                    postnet_loss.item(), decoder_loss.item(), stop_loss.item(), align_score,
-                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
-                    loader_time, current_lr),
+                    num_iter, batch_n_iter, global_step, postnet_loss.item(),
+                    decoder_loss.item(), stop_loss.item(), align_score,
+                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
+                    step_time, loader_time, current_lr),
                 flush=True)
 
         # aggregate losses from processes
@@ -210,26 +212,36 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
             decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
             loss = reduce_tensor(loss.data, num_gpus)
-            stop_loss = reduce_tensor(
-                stop_loss.data, num_gpus) if c.stopnet else stop_loss
+            stop_loss = reduce_tensor(stop_loss.data,
+                                      num_gpus) if c.stopnet else stop_loss
 
         if args.rank == 0:
-            update_train_values = {'avg_postnet_loss': float(postnet_loss.item()),
-                                   'avg_decoder_loss': float(decoder_loss.item()),
-                                   'avg_stop_loss': stop_loss if isinstance(stop_loss, float) else float(stop_loss.item()),
-                                   'avg_step_time': step_time,
-                                   'avg_loader_time': loader_time}
+            update_train_values = {
+                'avg_postnet_loss':
+                float(postnet_loss.item()),
+                'avg_decoder_loss':
+                float(decoder_loss.item()),
+                'avg_stop_loss':
+                stop_loss
+                if isinstance(stop_loss, float) else float(stop_loss.item()),
+                'avg_step_time':
+                step_time,
+                'avg_loader_time':
+                loader_time
+            }
             keep_avg.update_values(update_train_values)
 
             # Plot Training Iter Stats
             # reduce TB load
             if global_step % 10 == 0:
-                iter_stats = {"loss_posnet": postnet_loss.item(),
-                              "loss_decoder": decoder_loss.item(),
-                              "lr": current_lr,
-                              "grad_norm": grad_norm,
-                              "grad_norm_st": grad_norm_st,
-                              "step_time": step_time}
+                iter_stats = {
+                    "loss_posnet": postnet_loss.item(),
+                    "loss_decoder": decoder_loss.item(),
+                    "lr": current_lr,
+                    "grad_norm": grad_norm,
+                    "grad_norm_st": grad_norm_st,
+                    "step_time": step_time
+                }
                 tb_logger.tb_train_iter_stats(global_step, iter_stats)
 
         if global_step % c.save_step == 0:
@@ -242,7 +254,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                 # Diagnostic visualizations
                 const_spec = postnet_output[0].data.cpu().numpy()
                 gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
-                    "Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy()
+                    "Tacotron", "TacotronGST"
+                ] else mel_input[0].data.cpu().numpy()
                 align_img = alignments[0].data.cpu().numpy()
 
                 figures = {
@@ -263,23 +276,26 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     end_time = time.time()
 
     # print epoch stats
-    print(
-        " | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
-        "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
-        "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
-        "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, keep_avg['avg_postnet_loss'], keep_avg['avg_decoder_loss'],
-                                                         keep_avg['avg_stop_loss'], keep_avg['avg_align_score'],
-                                                         epoch_time, keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
-        flush=True)
+    print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+          "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
+          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
+          "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
+              global_step, keep_avg['avg_postnet_loss'],
+              keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
+              keep_avg['avg_align_score'], epoch_time,
+              keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
+          flush=True)
 
     # Plot Epoch Stats
     if args.rank == 0:
         # Plot Training Epoch Stats
-        epoch_stats = {"loss_postnet": keep_avg['avg_postnet_loss'],
-                       "loss_decoder": keep_avg['avg_decoder_loss'],
-                       "stop_loss": keep_avg['avg_stop_loss'],
-                       "alignment_score": keep_avg['avg_align_score'],
-                       "epoch_time": epoch_time}
+        epoch_stats = {
+            "loss_postnet": keep_avg['avg_postnet_loss'],
+            "loss_decoder": keep_avg['avg_decoder_loss'],
+            "stop_loss": keep_avg['avg_stop_loss'],
+            "alignment_score": keep_avg['avg_align_score'],
+            "epoch_time": epoch_time
+        }
         tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, global_step)
@@ -292,10 +308,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
     epoch_time = 0
-    eval_values_dict = {'avg_postnet_loss': 0,
-                        'avg_decoder_loss': 0,
-                        'avg_stop_loss': 0,
-                        'avg_align_score': 0}
+    eval_values_dict = {
+        'avg_postnet_loss': 0,
+        'avg_decoder_loss': 0,
+        'avg_stop_loss': 0,
+        'avg_align_score': 0
+    }
     keep_avg = KeepAverage()
     keep_avg.add_values(eval_values_dict)
     print("\n > Validation")
@@ -319,14 +337,17 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 text_lengths = data[1]
                 speaker_names = data[2]
                 linear_input = data[3] if c.model in [
-                    "Tacotron", "TacotronGST"] else None
+                    "Tacotron", "TacotronGST"
+                ] else None
                 mel_input = data[4]
                 mel_lengths = data[5]
                 stop_targets = data[6]
 
                 if c.use_speaker_embedding:
-                    speaker_ids = [speaker_mapping[speaker_name]
-                                   for speaker_name in speaker_names]
+                    speaker_ids = [
+                        speaker_mapping[speaker_name]
+                        for speaker_name in speaker_names
+                    ]
                     speaker_ids = torch.LongTensor(speaker_ids)
                 else:
                     speaker_ids = None
@@ -335,8 +356,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 stop_targets = stop_targets.view(text_input.shape[0],
                                                  stop_targets.size(1) // c.r,
                                                  -1)
-                stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(
-                    2).float().squeeze(2)
+                stop_targets = (stop_targets.sum(2) >
+                                0.0).unsqueeze(2).float().squeeze(2)
 
                 # dispatch data to GPU
                 if use_cuda:
@@ -344,7 +365,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                     mel_input = mel_input.cuda()
                     mel_lengths = mel_lengths.cuda()
                     linear_input = linear_input.cuda() if c.model in [
-                        "Tacotron", "TacotronGST"] else None
+                        "Tacotron", "TacotronGST"
+                    ] else None
                     stop_targets = stop_targets.cuda()
                     if speaker_ids is not None:
                         speaker_ids = speaker_ids.cuda()
@@ -358,14 +380,14 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 stop_loss = criterion_st(
                     stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                 if c.loss_masking:
-                    decoder_loss = criterion(
-                        decoder_output, mel_input, mel_lengths)
+                    decoder_loss = criterion(decoder_output, mel_input,
+                                             mel_lengths)
                     if c.model in ["Tacotron", "TacotronGST"]:
-                        postnet_loss = criterion(
-                            postnet_output, linear_input, mel_lengths)
+                        postnet_loss = criterion(postnet_output, linear_input,
+                                                 mel_lengths)
                     else:
-                        postnet_loss = criterion(
-                            postnet_output, mel_input, mel_lengths)
+                        postnet_loss = criterion(postnet_output, mel_input,
+                                                 mel_lengths)
                 else:
                     decoder_loss = criterion(decoder_output, mel_input)
                     if c.model in ["Tacotron", "TacotronGST"]:
@@ -388,19 +410,25 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                     if c.stopnet:
                         stop_loss = reduce_tensor(stop_loss.data, num_gpus)
 
-                keep_avg.update_values({'avg_postnet_loss': float(postnet_loss.item()),
-                                        'avg_decoder_loss': float(decoder_loss.item()),
-                                        'avg_stop_loss': float(stop_loss.item())})
+                keep_avg.update_values({
+                    'avg_postnet_loss':
+                    float(postnet_loss.item()),
+                    'avg_decoder_loss':
+                    float(decoder_loss.item()),
+                    'avg_stop_loss':
+                    float(stop_loss.item())
+                })
 
                 if num_iter % c.print_step == 0:
                     print(
                         " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
-                        "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(
-                            loss.item(),
-                            postnet_loss.item(), keep_avg['avg_postnet_loss'],
-                            decoder_loss.item(), keep_avg['avg_decoder_loss'],
-                            stop_loss.item(), keep_avg['avg_stop_loss'],
-                            align_score, keep_avg['avg_align_score']),
+                        "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
+                        .format(loss.item(), postnet_loss.item(),
+                                keep_avg['avg_postnet_loss'],
+                                decoder_loss.item(),
+                                keep_avg['avg_decoder_loss'], stop_loss.item(),
+                                keep_avg['avg_stop_loss'], align_score,
+                                keep_avg['avg_align_score']),
                         flush=True)
 
         if args.rank == 0:
@@ -408,7 +436,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             idx = np.random.randint(mel_input.shape[0])
             const_spec = postnet_output[idx].data.cpu().numpy()
             gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
-                "Tacotron", "TacotronGST"] else mel_input[idx].data.cpu().numpy()
+                "Tacotron", "TacotronGST"
+            ] else mel_input[idx].data.cpu().numpy()
             align_img = alignments[idx].data.cpu().numpy()
 
             eval_figures = {
@@ -423,13 +452,15 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 eval_audio = ap.inv_spectrogram(const_spec.T)
             else:
                 eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-            tb_logger.tb_eval_audios(
-                global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
+            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
+                                     c.audio["sample_rate"])
 
             # Plot Validation Stats
-            epoch_stats = {"loss_postnet": keep_avg['avg_postnet_loss'],
-                           "loss_decoder": keep_avg['avg_decoder_loss'],
-                           "stop_loss": keep_avg['avg_stop_loss']}
+            epoch_stats = {
+                "loss_postnet": keep_avg['avg_postnet_loss'],
+                "loss_decoder": keep_avg['avg_decoder_loss'],
+                "stop_loss": keep_avg['avg_stop_loss']
+            }
             tb_logger.tb_eval_stats(global_step, epoch_stats)
 
     if args.rank == 0 and epoch > c.test_delay_epochs:
@@ -442,7 +473,11 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
-                    model, test_sentence, c, use_cuda, ap,
+                    model,
+                    test_sentence,
+                    c,
+                    use_cuda,
+                    ap,
                     speaker_id=speaker_id,
                     style_wav=style_wav)
                 file_path = os.path.join(AUDIO_PATH, str(global_step))
@@ -451,15 +486,15 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                                          "TestSentence_{}.wav".format(idx))
                 ap.save_wav(wav, file_path)
                 test_audios['{}-audio'.format(idx)] = wav
-                test_figures['{}-prediction'.format(idx)
-                             ] = plot_spectrogram(postnet_output, ap)
-                test_figures['{}-alignment'.format(idx)
-                             ] = plot_alignment(alignment)
+                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
+                    postnet_output, ap)
+                test_figures['{}-alignment'.format(idx)] = plot_alignment(
+                    alignment)
             except:
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
-        tb_logger.tb_test_audios(
-            global_step, test_audios, c.audio['sample_rate'])
+        tb_logger.tb_test_audios(global_step, test_audios,
+                                 c.audio['sample_rate'])
         tb_logger.tb_test_figures(global_step, test_figures)
 
     return keep_avg['avg_postnet_loss']
@@ -490,8 +525,7 @@ def main(args):  # pylint: disable=redefined-outer-name
                        "introduce new speakers to " \
                        "a previously trained model."
         else:
-            speaker_mapping = {name: i
-                               for i, name in enumerate(speakers)}
+            speaker_mapping = {name: i for i, name in enumerate(speakers)}
             save_speaker_mapping(OUT_PATH, speaker_mapping)
         num_speakers = len(speaker_mapping)
         print("Training with {} speakers: {}".format(num_speakers,
@@ -506,18 +540,20 @@ def main(args):  # pylint: disable=redefined-outer-name
     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
-        optimizer_st = RAdam(
-            model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
+                             lr=c.lr,
+                             weight_decay=0)
     else:
         optimizer_st = None
 
     if c.loss_masking:
-        criterion = L1LossMasked() if c.model in [
-            "Tacotron", "TacotronGST"] else MSELossMasked()
+        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
+                                                  ] else MSELossMasked()
     else:
-        criterion = nn.L1Loss() if c.model in [
-            "Tacotron", "TacotronGST"] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0)) if c.stopnet else None
+        criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
+                                               ] else nn.MSELoss()
+    criterion_st = nn.BCEWithLogitsLoss(
+        pos_weight=torch.tensor(20.0)) if c.stopnet else None
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -536,8 +572,8 @@ def main(args):  # pylint: disable=redefined-outer-name
             del model_dict
         for group in optimizer.param_groups:
             group['lr'] = c.lr
-        print(
-            " > Model restored from step %d" % checkpoint['step'], flush=True)
+        print(" > Model restored from step %d" % checkpoint['step'],
+              flush=True)
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
@@ -553,8 +589,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             model = apply_gradient_allreduce(model)
 
     if c.lr_decay:
-        scheduler = NoamLR(
-            optimizer,
-            warmup_steps=c.warmup_steps,
-            last_epoch=args.restore_step - 1)
+        scheduler = NoamLR(optimizer,
+                           warmup_steps=c.warmup_steps,
+                           last_epoch=args.restore_step - 1)
     else:
@@ -576,12 +611,11 @@ def main(args):  # pylint: disable=redefined-outer-name
         print(" > Number of outputs per iteration:", model.decoder.r)
 
         train_loss, global_step = train(model, criterion, criterion_st,
-                                        optimizer, optimizer_st, scheduler,
-                                        ap, global_step, epoch)
-        val_loss = evaluate(model, criterion, criterion_st,
-                            ap, global_step, epoch)
-        print(
-            " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
-                train_loss, val_loss),
-            flush=True)
+                                        optimizer, optimizer_st, scheduler, ap,
+                                        global_step, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, ap, global_step,
+                            epoch)
+        print(" | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
+            train_loss, val_loss),
+              flush=True)
         target_loss = train_loss
@@ -603,8 +637,7 @@ if __name__ == '__main__':
         type=str,
         help='Path to config file for training.',
     )
-    parser.add_argument(
-        '--debug',
-        type=bool,
-        default=True,
-        help='Do not verify commit integrity to run training.')
+    parser.add_argument('--debug',
+                        type=bool,
+                        default=True,
+                        help='Do not verify commit integrity to run training.')
@@ -613,17 +646,14 @@ if __name__ == '__main__':
         type=str,
         default='',
         help='Defines the data path. It overwrites config.json.')
-    parser.add_argument(
-        '--output_path',
-        type=str,
-        help='path for training outputs.',
-        default='')
-    parser.add_argument(
-        '--output_folder',
-        type=str,
-        default='',
-        help='folder name for training outputs.'
-    )
+    parser.add_argument('--output_path',
+                        type=str,
+                        help='path for training outputs.',
+                        default='')
+    parser.add_argument('--output_folder',
+                        type=str,
+                        default='',
+                        help='folder name for training outputs.')
 
     # DISTRUBUTED
     parser.add_argument(
@@ -631,8 +661,7 @@ if __name__ == '__main__':
         type=int,
         default=0,
         help='DISTRIBUTED: process rank for distributed training.')
-    parser.add_argument(
-        '--group_id',
-        type=str,
-        default="",
-        help='DISTRIBUTED: process group id.')
+    parser.add_argument('--group_id',
+                        type=str,
+                        default="",
+                        help='DISTRIBUTED: process group id.')
@@ -662,8 +691,8 @@ if __name__ == '__main__':
     if args.restore_path:
         new_fields["restore_path"] = args.restore_path
     new_fields["github_branch"] = get_git_branch()
-    copy_config_file(args.config_path, os.path.join(
-        OUT_PATH, 'config.json'), new_fields)
+    copy_config_file(args.config_path,
+                     os.path.join(OUT_PATH, 'config.json'), new_fields)
     os.chmod(AUDIO_PATH, 0o775)
     os.chmod(OUT_PATH, 0o775)
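
The commit message does not name the tool that produced this layout, but the style (arguments aligned under the opening bracket, dict literals split to one key per line, four-space hanging indents) is characteristic of yapf. A minimal sketch of reproducing such a formatting pass, assuming yapf is the formatter in use and is installed:

    # Assumption: yapf produced these changes; the commit only says "linter fix".
    from yapf.yapflib.yapf_api import FormatFile

    # Rewrite train.py in place; passing print_diff=True instead would
    # return the diff rather than modifying the file.
    FormatFile("train.py", in_place=True)

The command-line equivalent, yapf -i train.py, applies the same style file-wide.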