mirror of https://github.com/coqui-ai/TTS.git
print max lengths in tacotron training
This commit is contained in:
parent
1229554c42
commit
d8c1b5b73d
|
@ -86,8 +86,8 @@ def format_data(data, speaker_mapping=None):
|
||||||
mel_input = data[4]
|
mel_input = data[4]
|
||||||
mel_lengths = data[5]
|
mel_lengths = data[5]
|
||||||
stop_targets = data[6]
|
stop_targets = data[6]
|
||||||
avg_text_length = torch.mean(text_lengths.float())
|
max_text_length = torch.max(text_lengths.float())
|
||||||
avg_spec_length = torch.mean(mel_lengths.float())
|
max_spec_length = torch.max(mel_lengths.float())
|
||||||
|
|
||||||
if c.use_speaker_embedding:
|
if c.use_speaker_embedding:
|
||||||
if c.use_external_speaker_embedding_file:
|
if c.use_external_speaker_embedding_file:
|
||||||
|
@ -123,7 +123,7 @@ def format_data(data, speaker_mapping=None):
|
||||||
if speaker_embeddings is not None:
|
if speaker_embeddings is not None:
|
||||||
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
|
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
|
||||||
|
|
||||||
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length
|
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, optimizer, optimizer_st, scheduler,
|
def train(model, criterion, optimizer, optimizer_st, scheduler,
|
||||||
|
@ -144,7 +144,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# format data
|
# format data
|
||||||
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping)
|
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data, speaker_mapping)
|
||||||
loader_time = time.time() - end_time
|
loader_time = time.time() - end_time
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
@ -255,8 +255,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
||||||
# print training progress
|
# print training progress
|
||||||
if global_step % c.print_step == 0:
|
if global_step % c.print_step == 0:
|
||||||
log_dict = {
|
log_dict = {
|
||||||
"avg_spec_length": [avg_spec_length, 1], # value, precision
|
"max_spec_length": [max_spec_length, 1], # value, precision
|
||||||
"avg_text_length": [avg_text_length, 1],
|
"max_text_length": [max_text_length, 1],
|
||||||
"step_time": [step_time, 4],
|
"step_time": [step_time, 4],
|
||||||
"loader_time": [loader_time, 2],
|
"loader_time": [loader_time, 2],
|
||||||
"current_lr": current_lr,
|
"current_lr": current_lr,
|
||||||
|
|
Loading…
Reference in New Issue