mirror of https://github.com/coqui-ai/TTS.git
enable gradual training by config.json
commit ee706b50f6
parent f038b1aa3f

Changed file: train.py (46 changed lines)
@@ -20,7 +20,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, remove_experiment_folder,
                                  save_best_model, save_checkpoint, weight_decay,
                                  set_init_dict, copy_config_file, setup_model,
-                                 split_dataset)
+                                 split_dataset, gradual_training_scheduler)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -81,20 +81,6 @@ def setup_loader(ap, is_val=False, verbose=False):
     return loader


-def gradual_training_scheduler(global_step):
-    if global_step < 10000:
-        r, batch_size = 7, 32
-    elif global_step < 50000:
-        r, batch_size = 5, 32
-    elif global_step < 130000:
-        r, batch_size = 3, 32
-    elif global_step < 290000:
-        r, batch_size = 2, 16
-    else:
-        r, batch_size = 1, 16
-    return r, batch_size
-
-
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
@@ -106,8 +92,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     avg_decoder_loss = 0
     avg_stop_loss = 0
     avg_step_time = 0
+    avg_loader_time = 0
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+    end_time = time.time()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

@@ -121,6 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         stop_targets = data[6]
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())
+        loader_time = time.time() - end_time

         if c.use_speaker_embedding:
             speaker_ids = [speaker_mapping[speaker_name]
@@ -191,17 +180,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         else:
             grad_norm_st = 0

-        step_time = time.time() - start_time
-        epoch_time += step_time
-
         if current_step % c.print_step == 0:
             print(
                 " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
                 "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
+                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
+                "LoaderTime:{:.2f} LR:{:.6f}".format(
                     num_iter, batch_n_iter, current_step, loss.item(),
                     postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
-                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
+                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
+                    loader_time, current_lr),
                 flush=True)

         # aggregate losses from processes
@@ -216,6 +204,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         avg_decoder_loss += float(decoder_loss.item())
         avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
         avg_step_time += step_time
+        avg_loader_time += loader_time

         # Plot Training Iter Stats
         iter_stats = {"loss_posnet": postnet_loss.item(),
@@ -254,11 +243,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                               {'TrainAudio': train_audio},
                               c.audio["sample_rate"])

+        step_time = end_time - start_time
+        epoch_time += step_time
+
     avg_postnet_loss /= (num_iter + 1)
     avg_decoder_loss /= (num_iter + 1)
     avg_stop_loss /= (num_iter + 1)
     avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
     avg_step_time /= (num_iter + 1)
+    avg_loader_time /= (num_iter + 1)

     # print epoch stats
     print(
@@ -267,7 +260,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
         "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                     avg_postnet_loss, avg_decoder_loss,
-                                    avg_stop_loss, epoch_time, avg_step_time),
+                                    avg_stop_loss, epoch_time, avg_step_time,
+                                    avg_loader_time),
         flush=True)

     # Plot Epoch Stats
@@ -281,6 +275,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, current_step)

+        end_time = time.time()
     return avg_postnet_loss, current_step


@@ -541,9 +536,10 @@ def main(args): #pylint: disable=redefined-outer-name
     current_step = 0
     for epoch in range(0, c.epochs):
         # set gradual training
-        r, c.batch_size = gradual_training_scheduler(current_step)
-        c.r = r
-        model.decoder._set_r(r)
+        if c.gradual_training is not None:
+            r, c.batch_size = gradual_training_scheduler(current_step, c)
+            c.r = r
+            model.decoder._set_r(r)
         print(" > Number of outputs per iteration:", model.decoder.r)

         train_loss, current_step = train(model, criterion, criterion_st,
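
With this change the epoch loop only reconfigures r and the batch size when config.json defines a gradual_training schedule. A minimal sketch of what that entry might look like, shown as the parsed Python value on the global config object c; the [start_step, r, batch_size] layout is an assumption inferred from how the new scheduler indexes each item, and the numbers simply reuse the hard-coded schedule removed from train.py above.

# Hypothetical "gradual_training" entry from config.json, shown as the value
# load_config() would attach to the config object. Each item is assumed to be
# [start_step, r, batch_size]; the values mirror the removed hard-coded schedule.
c.gradual_training = [
    [0,      7, 32],   # from step 0:      r=7, batch_size=32
    [10000,  5, 32],   # from step 10000:  r=5, batch_size=32
    [50000,  3, 32],   # from step 50000:  r=3, batch_size=32
    [130000, 2, 16],   # from step 130000: r=2, batch_size=16
    [290000, 1, 16],   # from step 290000: r=1, batch_size=16
]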
@@ -592,7 +588,7 @@ if __name__ == '__main__':
         '--output_folder',
         type=str,
         default='',
-        help='folder name for traning outputs.'
+        help='folder name for training outputs.'
     )

     # DISTRUBUTED
utils/generic_utils.py

@@ -305,3 +305,10 @@ def split_dataset(items):
     else:
         return items[:eval_split_size], items[eval_split_size:]

+
+def gradual_training_scheduler(global_step, config):
+    new_values = None
+    for values in config.gradual_training:
+        if global_step >= values[0]:
+            new_values = values
+    return new_values[1], new_values[2]
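
A minimal, self-contained usage sketch of the new scheduler, assuming the hypothetical schedule from the previous example; SimpleNamespace merely stands in for the object that load_config() returns.

from types import SimpleNamespace

# Stand-in for the loaded config; only the field the scheduler reads is set,
# using the hypothetical schedule shown earlier.
config = SimpleNamespace(gradual_training=[[0, 7, 32], [10000, 5, 32],
                                           [50000, 3, 32], [130000, 2, 16],
                                           [290000, 1, 16]])

def gradual_training_scheduler(global_step, config):
    # Copied from the hunk above: keep the last entry whose start step has
    # already been reached. If no entry matches (first start step above
    # global_step), new_values stays None, so schedules presumably start at 0.
    new_values = None
    for values in config.gradual_training:
        if global_step >= values[0]:
            new_values = values
    return new_values[1], new_values[2]

print(gradual_training_scheduler(0, config))       # -> (7, 32)
print(gradual_training_scheduler(60000, config))   # -> (3, 32)
print(gradual_training_scheduler(300000, config))  # -> (1, 16)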