mirror of https://github.com/coqui-ai/TTS.git
Revert train.py
parent 24644b20d4
commit 65ea7b0afb
train.py | 88
@@ -13,16 +13,18 @@ import numpy as np
 import torch.nn as nn
 from torch import optim
 from torch import onnx
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tensorboardX import SummaryWriter

 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters, check_update, get_commit_hash,
-                                 create_attn_mask, mk_decay)
+                                 count_parameters, check_update, get_commit_hash)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
+from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked

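Note: the revert leaves the L1LossMasked import in place. For context, a minimal sketch of a length-masked L1 loss in the same spirit; this is an assumption about the semantics of layers.losses.L1LossMasked, not the repo's exact implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class L1LossMasked(nn.Module):
    # Sketch: average L1 only over frames inside each sample's true length.
    def forward(self, output, target, lengths):
        max_len = target.size(1)
        ids = torch.arange(max_len, device=target.device).unsqueeze(0)  # (1, T)
        mask = (ids < lengths.unsqueeze(1)).float()                     # (B, T)
        mask = mask.unsqueeze(2).expand_as(target)                      # (B, T, D)
        return F.l1_loss(output * mask, target * mask, reduction='sum') / mask.sum()

# Toy usage with hypothetical shapes: batch 2, 5 frames, 3 features.
out, tgt = torch.rand(2, 5, 3), torch.rand(2, 5, 3)
loss = L1LossMasked()(out, tgt, torch.tensor([5, 3]))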
@@ -65,15 +67,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_stop_loss = 0
     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
     progbar_display = {}
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

         # setup input data
         text_input = data[0]
         text_lengths = data[1]
-        linear_spec = data[2]
-        mel_spec = data[3]
+        linear_input = data[2]
+        mel_input = data[3]
         mel_lengths = data[4]
         stop_targets = data[5]

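Note: n_priority_freq above is the number of linear-spectrogram bins below 3 kHz (3000 Hz divided by the Nyquist frequency sample_rate / 2, times num_freq). Tacotron-style recipes typically use it to up-weight low frequencies in the linear loss; a self-contained sketch with hypothetical values (the exact weighting in this train.py is not shown in the hunk):

import torch
import torch.nn.functional as F

sample_rate, num_freq = 22050, 1025
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)  # bins covering 0..3 kHz

# Dummy batch: 2 samples, 40 frames, num_freq bins.
linear_output = torch.rand(2, 40, num_freq)
linear_input = torch.rand(2, 40, num_freq)

# Equal-weight mix of full-band and low-band L1.
linear_loss = 0.5 * F.l1_loss(linear_output, linear_input) \
    + 0.5 * F.l1_loss(linear_output[:, :, :n_priority_freq],
                      linear_input[:, :, :n_priority_freq])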
@@ -138,12 +140,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         step_time = time.time() - start_time
         epoch_time += step_time

-        progbar_display['total_loss'] = loss.item()
-        progbar_display['linear_loss'] = linear_loss.item()
-        progbar_display['mel_loss'] = mel_loss.item()
-        progbar_display['stop_loss'] = stop_loss.item()
-        progbar_display['grad_norm'] = grad_norm.item()
-
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.item()),
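Note: Progbar from utils.generic_utils appears to follow the Keras progress-bar API: construct with a target count, then call update(current, values=[(name, value), ...]) each step, exactly as above. A tiny runnable stand-in with the same call shape (not the repo's implementation):

import time

class Progbar:
    # Minimal stand-in: prints "current/target name: value" on one line.
    def __init__(self, target):
        self.target = target
    def update(self, current, values=None):
        stats = " - ".join("{}: {:.4f}".format(k, v) for k, v in (values or []))
        print("\r{}/{} {}".format(current, int(self.target), stats), end="")

progbar = Progbar(10)
for i in range(10):
    time.sleep(0.01)
    progbar.update(i + 1, values=[('total_loss', 1.0 / (i + 1))])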
@@ -208,7 +204,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
     tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
-    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
     tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
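Note: tb is a tensorboardX SummaryWriter, and add_scalar(tag, value, step) is its standard API; the slash in the tag groups curves under one dashboard section. Minimal usage matching the calls above (log directory and values are placeholders):

from tensorboardX import SummaryWriter

tb = SummaryWriter(log_dir="logs/run1")   # hypothetical log directory
avg_total_loss, current_step = 0.42, 100  # dummy values
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.close()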
@@ -277,7 +272,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
             avg_stop_loss += stop_loss.item()

             # Diagnostic visualizations
-            idx = np.random.randint(mel_spec.shape[0])
+            idx = np.random.randint(mel_input.shape[0])
             const_spec = linear_output[idx].data.cpu().numpy()
             gt_spec = linear_input[idx].data.cpu().numpy()
             align_img = alignments[idx].data.cpu().numpy()
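Note: the align_img array above is fed to utils.visual.plot_alignment, whose exact signature is not shown in this diff. A sketch of the usual way attention alignments are rendered for TensorBoard (matplotlib only; the function name here is hypothetical):

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for training servers
import matplotlib.pyplot as plt

def plot_alignment_sketch(align_img: np.ndarray):
    # align_img: (decoder_steps, encoder_steps) attention weights in [0, 1].
    fig, ax = plt.subplots()
    im = ax.imshow(align_img.T, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Decoder step")
    ax.set_ylabel("Encoder step")
    return fig

fig = plot_alignment_sketch(np.random.rand(60, 40))  # dummy alignment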
@@ -314,50 +309,48 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
     tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
     tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)

     return avg_linear_loss


 def main(args):
-    print(" > Using dataset: {}".format(c.dataset))
-    mod = importlib.import_module('datasets.{}'.format(c.dataset))
-    Dataset = getattr(mod, c.dataset+"Dataset")
-
     # Setup the dataset
-    train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train),
-                            os.path.join(c.data_path, 'wavs'),
-                            c.r,
-                            c.sample_rate,
-                            c.text_cleaner,
-                            c.num_mels,
-                            c.min_level_db,
-                            c.frame_shift_ms,
-                            c.frame_length_ms,
-                            c.preemphasis,
-                            c.ref_level_db,
-                            c.num_freq,
-                            c.power,
-                            min_seq_len=c.min_seq_len
-                            )
+    train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
+                                    os.path.join(c.data_path, 'wavs'),
+                                    c.r,
+                                    c.sample_rate,
+                                    c.text_cleaner,
+                                    c.num_mels,
+                                    c.min_level_db,
+                                    c.frame_shift_ms,
+                                    c.frame_length_ms,
+                                    c.preemphasis,
+                                    c.ref_level_db,
+                                    c.num_freq,
+                                    c.power,
+                                    min_seq_len=c.min_seq_len
+                                    )

     train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
                               shuffle=False, collate_fn=train_dataset.collate_fn,
-                              drop_last=True, num_workers=c.num_loader_workers,
+                              drop_last=False, num_workers=c.num_loader_workers,
                               pin_memory=True)

-    val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),
-                          os.path.join(c.data_path, 'wavs'),
-                          c.r,
-                          c.sample_rate,
-                          c.text_cleaner,
-                          c.num_mels,
-                          c.min_level_db,
-                          c.frame_shift_ms,
-                          c.frame_length_ms,
-                          c.preemphasis,
-                          c.ref_level_db,
-                          c.num_freq,
-                          c.power
-                          )
+    val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
+                                  os.path.join(c.data_path, 'wavs'),
+                                  c.r,
+                                  c.sample_rate,
+                                  c.text_cleaner,
+                                  c.num_mels,
+                                  c.min_level_db,
+                                  c.frame_shift_ms,
+                                  c.frame_length_ms,
+                                  c.preemphasis,
+                                  c.ref_level_db,
+                                  c.num_freq,
+                                  c.power
+                                  )

     val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
                             shuffle=False, collate_fn=val_dataset.collate_fn,
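Note on the drop_last change in the training loader: with drop_last=True, DataLoader silently discards a final batch smaller than batch_size; with drop_last=False it keeps the short batch. A toy illustration (hypothetical ten-item dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
print(len(list(DataLoader(ds, batch_size=4, drop_last=True))))   # 2 batches, 2 items dropped
print(len(list(DataLoader(ds, batch_size=4, drop_last=False))))  # 3 batches, last one short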
@@ -392,8 +385,7 @@ def main(args):
         optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
     else:
         args.restore_step = 0
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
-        print(" > Starting a new training")
+        print("\n > Starting a new training")

     if use_cuda:
         model = nn.DataParallel(model.cuda())
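Note: the else branch above is the fresh-start half of the usual restore-or-start-new pattern, after which the model is wrapped for multi-GPU use. A minimal sketch of that pattern with a stand-in model (checkpoint path and state-dict keys are hypothetical, not this repo's format):

import os
import torch
import torch.nn as nn
from torch import optim

model = nn.Linear(10, 10)        # stand-in for the Tacotron model
optimizer = optim.Adam(model.parameters(), lr=1e-3)

restore_path = "checkpoint.pth"  # hypothetical path
if os.path.isfile(restore_path):
    state = torch.load(restore_path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
else:
    print("\n > Starting a new training")

if torch.cuda.is_available():
    # Wrap after restoring so the state-dict keys match the unwrapped model.
    model = nn.DataParallel(model.cuda())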