print max lengths in tacotron training

This commit is contained in:
erogol 2020-11-25 14:49:07 +01:00
parent 1229554c42
commit d8c1b5b73d
1 changed files with 6 additions and 6 deletions

View File

@ -86,8 +86,8 @@ def format_data(data, speaker_mapping=None):
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
@ -123,7 +123,7 @@ def format_data(data, speaker_mapping=None):
if speaker_embeddings is not None:
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
def train(model, criterion, optimizer, optimizer_st, scheduler,
@ -144,7 +144,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping)
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data, speaker_mapping)
loader_time = time.time() - end_time
global_step += 1
@ -255,8 +255,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
# print training progress
if global_step % c.print_step == 0:
log_dict = {
"avg_spec_length": [avg_spec_length, 1], # value, precision
"avg_text_length": [avg_text_length, 1],
"max_spec_length": [max_spec_length, 1], # value, precision
"max_text_length": [max_text_length, 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,