From d8c1b5b73d51504a60bb32953e22794756ac8c57 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 25 Nov 2020 14:49:07 +0100 Subject: [PATCH] print max lengths in tacotron training --- TTS/bin/train_tacotron.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index 6c12e54b..1263a616 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -86,8 +86,8 @@ def format_data(data, speaker_mapping=None): mel_input = data[4] mel_lengths = data[5] stop_targets = data[6] - avg_text_length = torch.mean(text_lengths.float()) - avg_spec_length = torch.mean(mel_lengths.float()) + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: @@ -123,7 +123,7 @@ def format_data(data, speaker_mapping=None): if speaker_embeddings is not None: speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length + return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length def train(model, criterion, optimizer, optimizer_st, scheduler, @@ -144,7 +144,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping) + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data, speaker_mapping) loader_time = time.time() - end_time global_step += 1 @@ -255,8 +255,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, # print training progress if global_step % c.print_step == 0: log_dict = { - "avg_spec_length": [avg_spec_length, 1], # value, precision - "avg_text_length": [avg_text_length, 1], + "max_spec_length": [max_spec_length, 1], # value, precision + "max_text_length": [max_text_length, 1], "step_time": [step_time, 4], "loader_time": [loader_time, 2], "current_lr": current_lr,