Revert train.py

This commit is contained in:
Eren Golge 2018-05-25 05:17:08 -07:00
parent 24644b20d4
commit 65ea7b0afb
1 changed files with 40 additions and 48 deletions

View File

@ -13,16 +13,18 @@ import numpy as np
import torch.nn as nn
from torch import optim
from torch import onnx
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay,
count_parameters, check_update, get_commit_hash,
create_attn_mask, mk_decay)
count_parameters, check_update, get_commit_hash)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
@ -65,15 +67,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
avg_stop_loss = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
progbar_display = {}
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_spec = data[2]
mel_spec = data[3]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_targets = data[5]
@ -138,12 +140,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
step_time = time.time() - start_time
epoch_time += step_time
progbar_display['total_loss'] = loss.item()
progbar_display['linear_loss'] = linear_loss.item()
progbar_display['mel_loss'] = mel_loss.item()
progbar_display['stop_loss'] = stop_loss.item()
progbar_display['grad_norm'] = grad_norm.item()
# update
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
@ -208,7 +204,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
@ -277,7 +272,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
avg_stop_loss += stop_loss.item()
# Diagnostic visualizations
idx = np.random.randint(mel_spec.shape[0])
idx = np.random.randint(mel_input.shape[0])
const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
@ -314,50 +309,48 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
return avg_linear_loss
def main(args):
print(" > Using dataset: {}".format(c.dataset))
mod = importlib.import_module('datasets.{}'.format(c.dataset))
Dataset = getattr(mod, c.dataset+"Dataset")
# Setup the dataset
train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power,
min_seq_len=c.min_seq_len
)
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power,
min_seq_len=c.min_seq_len
)
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
shuffle=False, collate_fn=train_dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers,
drop_last=False, num_workers=c.num_loader_workers,
pin_memory=True)
val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
shuffle=False, collate_fn=val_dataset.collate_fn,
@ -392,8 +385,7 @@ def main(args):
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
else:
args.restore_step = 0
optimizer = optim.Adam(model.parameters(), lr=c.lr)
print(" > Starting a new training")
print("\n > Starting a new training")
if use_cuda:
model = nn.DataParallel(model.cuda())