diff --git a/train.py b/train.py index 241ba22e..f570b7fe 100644 --- a/train.py +++ b/train.py @@ -13,16 +13,18 @@ import numpy as np import torch.nn as nn from torch import optim +from torch import onnx from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import ReduceLROnPlateau from tensorboardX import SummaryWriter from utils.generic_utils import (Progbar, remove_experiment_folder, create_experiment_folder, save_checkpoint, save_best_model, load_config, lr_decay, - count_parameters, check_update, get_commit_hash, - create_attn_mask, mk_decay) + count_parameters, check_update, get_commit_hash) from utils.model import get_param_size from utils.visual import plot_alignment, plot_spectrogram +from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron from layers.losses import L1LossMasked @@ -65,15 +67,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, avg_stop_loss = 0 print(" | > Epoch {}/{}".format(epoch, c.epochs)) progbar = Progbar(len(data_loader.dataset) / c.batch_size) - progbar_display = {} + n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) for num_iter, data in enumerate(data_loader): start_time = time.time() # setup input data text_input = data[0] text_lengths = data[1] - linear_spec = data[2] - mel_spec = data[3] + linear_input = data[2] + mel_input = data[3] mel_lengths = data[4] stop_targets = data[5] @@ -138,12 +140,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, step_time = time.time() - start_time epoch_time += step_time - - progbar_display['total_loss'] = loss.item() - progbar_display['linear_loss'] = linear_loss.item() - progbar_display['mel_loss'] = mel_loss.item() - progbar_display['stop_loss'] = stop_loss.item() - progbar_display['grad_norm'] = grad_norm.item() # update progbar.update(num_iter+1, values=[('total_loss', loss.item()), @@ -208,7 +204,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, # Plot Training Epoch Stats tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step) tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step) - tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step) tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step) tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step) tb.add_scalar('Time/EpochTime', epoch_time, epoch) @@ -277,7 +272,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): avg_stop_loss += stop_loss.item() # Diagnostic visualizations - idx = np.random.randint(mel_spec.shape[0]) + idx = np.random.randint(mel_input.shape[0]) const_spec = linear_output[idx].data.cpu().numpy() gt_spec = linear_input[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy() @@ -314,50 +309,48 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step) + return avg_linear_loss def main(args): - print(" > Using dataset: {}".format(c.dataset)) - mod = importlib.import_module('datasets.{}'.format(c.dataset)) - Dataset = getattr(mod, c.dataset+"Dataset") # Setup the dataset - train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train), - os.path.join(c.data_path, 'wavs'), - c.r, - c.sample_rate, - c.text_cleaner, - c.num_mels, - c.min_level_db, - c.frame_shift_ms, - c.frame_length_ms, - c.preemphasis, - c.ref_level_db, - c.num_freq, - c.power, - min_seq_len=c.min_seq_len - ) + train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), + os.path.join(c.data_path, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power, + min_seq_len=c.min_seq_len + ) train_loader = DataLoader(train_dataset, batch_size=c.batch_size, shuffle=False, collate_fn=train_dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers, + drop_last=False, num_workers=c.num_loader_workers, pin_memory=True) - val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val), - os.path.join(c.data_path, 'wavs'), - c.r, - c.sample_rate, - c.text_cleaner, - c.num_mels, - c.min_level_db, - c.frame_shift_ms, - c.frame_length_ms, - c.preemphasis, - c.ref_level_db, - c.num_freq, - c.power - ) + val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'), + os.path.join(c.data_path, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power + ) val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size, shuffle=False, collate_fn=val_dataset.collate_fn, @@ -392,8 +385,7 @@ def main(args): optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr) else: args.restore_step = 0 - optimizer = optim.Adam(model.parameters(), lr=c.lr) - print(" > Starting a new training") + print("\n > Starting a new training") if use_cuda: model = nn.DataParallel(model.cuda())