Revert train.py

Eren Golge 2018-05-25 05:17:08 -07:00
parent 24644b20d4
commit 65ea7b0afb
1 changed file with 40 additions and 48 deletions

train.py

@@ -13,16 +13,18 @@ import numpy as np
 import torch.nn as nn
 from torch import optim
+from torch import onnx
 from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tensorboardX import SummaryWriter
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters, check_update, get_commit_hash,
-                                 create_attn_mask, mk_decay)
+                                 count_parameters, check_update, get_commit_hash)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
+from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
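
Of the re-added imports, ReduceLROnPlateau is the one with non-obvious wiring: it is normally attached to an optimizer and stepped with a monitored metric such as the validation loss. A minimal sketch of that pattern (the model, learning rate, factor, and patience below are illustrative, not values from this config):

    import torch
    from torch import optim
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(10, 1)                 # stand-in for the Tacotron model
    optimizer = optim.Adam(model.parameters(), lr=0.002)
    # Halve the learning rate once the monitored loss stops improving for 5 epochs.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    for epoch in range(20):
        val_loss = 1.0 / (epoch + 1)               # placeholder for evaluate(...)
        scheduler.step(val_loss)                   # scheduler watches this value for plateaus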
@@ -65,15 +67,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_stop_loss = 0
     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
-    progbar_display = {}
+    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

         # setup input data
         text_input = data[0]
         text_lengths = data[1]
-        linear_spec = data[2]
-        mel_spec = data[3]
+        linear_input = data[2]
+        mel_input = data[3]
         mel_lengths = data[4]
         stop_targets = data[5]
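
The restored n_priority_freq line estimates how many linear-spectrogram bins lie below roughly 3 kHz; Tacotron-style recipes often weight those low-frequency bins more heavily in the linear loss. A quick numeric check (the sample_rate and num_freq values are assumed LJSpeech-style defaults, not read from this config):

    sample_rate = 22050   # assumed audio sample rate
    num_freq = 1025       # assumed number of linear-spectrogram bins

    # Bins run linearly from 0 Hz to the Nyquist frequency (sample_rate / 2),
    # so 3000 / (sample_rate * 0.5) is the fraction of bins below 3 kHz.
    n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)
    print(n_priority_freq)  # -> 278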
@@ -138,12 +140,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         step_time = time.time() - start_time
         epoch_time += step_time

-        progbar_display['total_loss'] = loss.item()
-        progbar_display['linear_loss'] = linear_loss.item()
-        progbar_display['mel_loss'] = mel_loss.item()
-        progbar_display['stop_loss'] = stop_loss.item()
-        progbar_display['grad_norm'] = grad_norm.item()
-
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.item()),
@@ -208,7 +204,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
     tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
-    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
     tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
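
Here tb is a tensorboardX SummaryWriter, and each add_scalar call appends one point to a named curve at the given global step. A minimal sketch of that logging pattern (the log directory and metric values are illustrative):

    from tensorboardX import SummaryWriter

    tb = SummaryWriter("logs/example_run")          # illustrative log directory
    for current_step in range(1, 4):
        avg_total_loss = 1.0 / current_step         # placeholder metric
        tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.close()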
@@ -277,7 +272,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
            avg_stop_loss += stop_loss.item()

            # Diagnostic visualizations
-           idx = np.random.randint(mel_spec.shape[0])
+           idx = np.random.randint(mel_input.shape[0])
            const_spec = linear_output[idx].data.cpu().numpy()
            gt_spec = linear_input[idx].data.cpu().numpy()
            align_img = alignments[idx].data.cpu().numpy()
@@ -314,50 +309,48 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
     tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
     tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
     return avg_linear_loss


 def main(args):
-    print(" > Using dataset: {}".format(c.dataset))
-    mod = importlib.import_module('datasets.{}'.format(c.dataset))
-    Dataset = getattr(mod, c.dataset+"Dataset")
-
     # Setup the dataset
-    train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train),
+    train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
                                     os.path.join(c.data_path, 'wavs'),
                                     c.r,
                                     c.sample_rate,
                                     c.text_cleaner,
                                     c.num_mels,
                                     c.min_level_db,
                                     c.frame_shift_ms,
                                     c.frame_length_ms,
                                     c.preemphasis,
                                     c.ref_level_db,
                                     c.num_freq,
                                     c.power,
                                     min_seq_len=c.min_seq_len
                                     )

     train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
                               shuffle=False, collate_fn=train_dataset.collate_fn,
-                              drop_last=True, num_workers=c.num_loader_workers,
+                              drop_last=False, num_workers=c.num_loader_workers,
                               pin_memory=True)

-    val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),
+    val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
                                   os.path.join(c.data_path, 'wavs'),
                                   c.r,
                                   c.sample_rate,
                                   c.text_cleaner,
                                   c.num_mels,
                                   c.min_level_db,
                                   c.frame_shift_ms,
                                   c.frame_length_ms,
                                   c.preemphasis,
                                   c.ref_level_db,
                                   c.num_freq,
                                   c.power
                                   )

     val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
                             shuffle=False, collate_fn=val_dataset.collate_fn,
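
Both loaders rely on the dataset's collate_fn, so each batch arrives as a positional tuple that the training loop above indexes as data[0] through data[5]. A sketch of how that output is consumed, continuing from the train_loader defined here (the commented shapes are assumptions, not taken from the diff):

    for num_iter, data in enumerate(train_loader):
        text_input = data[0]     # padded character IDs,         assumed [B, max_text_len]
        text_lengths = data[1]   # original text lengths,        assumed [B]
        linear_input = data[2]   # padded linear spectrograms,   assumed [B, max_spec_len, num_freq]
        mel_input = data[3]      # padded mel spectrograms,      assumed [B, max_spec_len, num_mels]
        mel_lengths = data[4]    # original spectrogram lengths, assumed [B]
        stop_targets = data[5]   # per-frame stop-token targets, assumed [B, max_spec_len]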
@@ -392,8 +385,7 @@ def main(args):
         optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
     else:
         args.restore_step = 0
-        optimizer = optim.Adam(model.parameters(), lr=c.lr)
-        print(" > Starting a new training")
+        print("\n > Starting a new training")

     if use_cuda:
         model = nn.DataParallel(model.cuda())
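
When use_cuda is set, the model is wrapped in nn.DataParallel, which splits each input batch across the visible GPUs on forward() and gathers the outputs; the underlying module stays reachable via .module. A small self-contained sketch (the Linear layer is a stand-in for the Tacotron model):

    import torch
    from torch import nn

    model = nn.Linear(8, 2)                      # stand-in for the Tacotron model
    if torch.cuda.is_available():
        model = nn.DataParallel(model.cuda())    # replicates the module across GPUs
        print(model.module)                      # original (unwrapped) module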