mirror of https://github.com/coqui-ai/TTS.git
Revert train.py
This commit is contained in:
parent
24644b20d4
commit
65ea7b0afb
88
train.py
88
train.py
|
@ -13,16 +13,18 @@ import numpy as np
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
from torch import onnx
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
save_best_model, load_config, lr_decay,
|
save_best_model, load_config, lr_decay,
|
||||||
count_parameters, check_update, get_commit_hash,
|
count_parameters, check_update, get_commit_hash)
|
||||||
create_attn_mask, mk_decay)
|
|
||||||
from utils.model import get_param_size
|
from utils.model import get_param_size
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
|
||||||
|
@ -65,15 +67,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
avg_stop_loss = 0
|
avg_stop_loss = 0
|
||||||
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
||||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||||
progbar_display = {}
|
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# setup input data
|
# setup input data
|
||||||
text_input = data[0]
|
text_input = data[0]
|
||||||
text_lengths = data[1]
|
text_lengths = data[1]
|
||||||
linear_spec = data[2]
|
linear_input = data[2]
|
||||||
mel_spec = data[3]
|
mel_input = data[3]
|
||||||
mel_lengths = data[4]
|
mel_lengths = data[4]
|
||||||
stop_targets = data[5]
|
stop_targets = data[5]
|
||||||
|
|
||||||
|
@ -138,12 +140,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
progbar_display['total_loss'] = loss.item()
|
|
||||||
progbar_display['linear_loss'] = linear_loss.item()
|
|
||||||
progbar_display['mel_loss'] = mel_loss.item()
|
|
||||||
progbar_display['stop_loss'] = stop_loss.item()
|
|
||||||
progbar_display['grad_norm'] = grad_norm.item()
|
|
||||||
|
|
||||||
# update
|
# update
|
||||||
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
||||||
|
@ -208,7 +204,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
||||||
tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
|
tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
|
||||||
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
|
|
||||||
tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
|
tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
|
||||||
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
|
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
|
||||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||||
|
@ -277,7 +272,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
|
||||||
avg_stop_loss += stop_loss.item()
|
avg_stop_loss += stop_loss.item()
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
idx = np.random.randint(mel_spec.shape[0])
|
idx = np.random.randint(mel_input.shape[0])
|
||||||
const_spec = linear_output[idx].data.cpu().numpy()
|
const_spec = linear_output[idx].data.cpu().numpy()
|
||||||
gt_spec = linear_input[idx].data.cpu().numpy()
|
gt_spec = linear_input[idx].data.cpu().numpy()
|
||||||
align_img = alignments[idx].data.cpu().numpy()
|
align_img = alignments[idx].data.cpu().numpy()
|
||||||
|
@ -314,50 +309,48 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
|
||||||
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
|
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
|
||||||
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
|
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
|
||||||
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
|
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
|
||||||
|
|
||||||
return avg_linear_loss
|
return avg_linear_loss
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
print(" > Using dataset: {}".format(c.dataset))
|
|
||||||
mod = importlib.import_module('datasets.{}'.format(c.dataset))
|
|
||||||
Dataset = getattr(mod, c.dataset+"Dataset")
|
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train),
|
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
|
||||||
os.path.join(c.data_path, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
c.min_level_db,
|
c.min_level_db,
|
||||||
c.frame_shift_ms,
|
c.frame_shift_ms,
|
||||||
c.frame_length_ms,
|
c.frame_length_ms,
|
||||||
c.preemphasis,
|
c.preemphasis,
|
||||||
c.ref_level_db,
|
c.ref_level_db,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.power,
|
c.power,
|
||||||
min_seq_len=c.min_seq_len
|
min_seq_len=c.min_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
||||||
shuffle=False, collate_fn=train_dataset.collate_fn,
|
shuffle=False, collate_fn=train_dataset.collate_fn,
|
||||||
drop_last=True, num_workers=c.num_loader_workers,
|
drop_last=False, num_workers=c.num_loader_workers,
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
|
|
||||||
val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),
|
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
|
||||||
os.path.join(c.data_path, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
c.min_level_db,
|
c.min_level_db,
|
||||||
c.frame_shift_ms,
|
c.frame_shift_ms,
|
||||||
c.frame_length_ms,
|
c.frame_length_ms,
|
||||||
c.preemphasis,
|
c.preemphasis,
|
||||||
c.ref_level_db,
|
c.ref_level_db,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.power
|
c.power
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
|
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
|
||||||
shuffle=False, collate_fn=val_dataset.collate_fn,
|
shuffle=False, collate_fn=val_dataset.collate_fn,
|
||||||
|
@ -392,8 +385,7 @@ def main(args):
|
||||||
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
|
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
|
||||||
else:
|
else:
|
||||||
args.restore_step = 0
|
args.restore_step = 0
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
print("\n > Starting a new training")
|
||||||
print(" > Starting a new training")
|
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
model = nn.DataParallel(model.cuda())
|
model = nn.DataParallel(model.cuda())
|
||||||
|
|
Loading…
Reference in New Issue