diff --git a/datasets/preprocess.py b/datasets/preprocess.py
index 0aabbdf6..89d464a0 100644
--- a/datasets/preprocess.py
+++ b/datasets/preprocess.py
@@ -55,6 +55,6 @@ def nancy(root_path, meta_file):
             id = line.split()[1]
             text = line[line.find('"')+1:line.rfind('"')-1]
             wav_file = root_path + 'wavn/' + id + '.wav'
-            items.append(text, wav_file)
+            items.append([text, wav_file])
     random.shuffle(items)
     return items
\ No newline at end of file
diff --git a/train.py b/train.py
index 3613d1dc..6d4e7975 100644
--- a/train.py
+++ b/train.py
@@ -51,6 +51,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         mel_input = data[3]
         mel_lengths = data[4]
         stop_targets = data[5]
+        avg_text_length = torch.mean(text_lengths.float())
 
         # set stop targets view, we predict a single stop token per r frames prediction
         stop_targets = stop_targets.view(text_input.shape[0],
@@ -68,13 +69,13 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # dispatch data to GPU
         if use_cuda:
-            text_input = text_input.cuda()
-            text_lengths = text_lengths.cuda()
-            mel_input = mel_input.cuda()
-            mel_lengths = mel_lengths.cuda()
-            linear_input = linear_input.cuda()
-            stop_targets = stop_targets.cuda()
-
+            text_input = text_input.cuda(non_blocking=True)
+            text_lengths = text_lengths.cuda(non_blocking=True)
+            mel_input = mel_input.cuda(non_blocking=True)
+            mel_lengths = mel_lengths.cuda(non_blocking=True)
+            linear_input = linear_input.cuda(non_blocking=True)
+            stop_targets = stop_targets.cuda(non_blocking=True)
+
         # compute mask for padding
         mask = sequence_mask(text_lengths)
 
@@ -129,10 +130,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             print(
                 " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
                 "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                "GradNormST:{:.5f} StepTime:{:.2f} LR:{:.6f}".format(
+                "GradNormST:{:.5f} AvgTextLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
                     num_iter, batch_n_iter, current_step, loss.item(),
                     linear_loss.item(), mel_loss.item(), stop_loss.item(),
-                    grad_norm, grad_norm_st, step_time, current_lr),
+                    grad_norm, grad_norm_st, avg_text_length, step_time, current_lr),
                 flush=True)
 
         avg_linear_loss += linear_loss.item()