coqui-tts/TTS/bin/train_glow_tts.py

655 lines
25 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import glob
import os
import sys
import time
import traceback
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.utils.console_logger import ConsoleLogger
from TTS.tts.utils.distribute import (DistributedSampler,
init_distributed,
reduce_tensor)
from TTS.tts.utils.generic_utils import check_config, setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import (get_speakers,
load_speaker_mapping,
save_speaker_mapping)
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (
KeepAverage, count_parameters, create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import (NoamLR, adam_weight_decay,
check_update,
gradual_training_scheduler,
set_weight_decay,
setup_torch_training_env)
# Module-level device setup: `use_cuda` and `num_gpus` are read as globals by
# the functions below (data loading, batch formatting, train/eval and main).
# NOTE(review): the (True, False) arguments are presumably cuDNN
# enable/benchmark switches — confirm against TTS.utils.training.
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False):
    """Build the DataLoader for the training or the evaluation split.

    Reads the module-level config ``c``, the ``meta_data_train`` /
    ``meta_data_eval`` lists published by ``main`` and ``num_gpus``.

    Args:
        ap: AudioProcessor handed to the dataset.
        r: first positional argument of ``MyDataset`` (presumably the
            decoder reduction factor; every caller in this file passes 1).
        is_val: build the evaluation loader instead of the training one.
        verbose: forwarded to ``MyDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when ``is_val`` is set and
        ``c.run_eval`` is disabled.
    """
    if is_val and not c.run_eval:
        return None
    dataset = MyDataset(
        r,
        c.text_cleaner,
        # Linear spectrograms are only computed for the 'tacotron' model.
        compute_linear_spec=c.model.lower() == 'tacotron',
        meta_data=meta_data_eval if is_val else meta_data_train,
        ap=ap,
        tp=c.characters if 'characters' in c.keys() else None,
        batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
        min_seq_len=c.min_seq_len,
        max_seq_len=c.max_seq_len,
        phoneme_cache_path=c.phoneme_cache_path,
        use_phonemes=c.use_phonemes,
        phoneme_language=c.phoneme_language,
        enable_eos_bos=c.enable_eos_bos_chars,
        verbose=verbose)
    # Each process gets its own shard of the data in multi-GPU runs.
    sampler = DistributedSampler(dataset) if num_gpus > 1 else None
    return DataLoader(
        dataset,
        batch_size=c.eval_batch_size if is_val else c.batch_size,
        # NOTE(review): shuffling stays off even without a sampler; presumably
        # MyDataset orders/groups samples itself (batch_group_size) — confirm.
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=sampler,
        num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
        pin_memory=False)
# Speaker mappings already read from disk, keyed by experiment folder, so the
# mapping file is not re-read for every single batch.
_SPEAKER_MAPPING_CACHE = {}


def format_data(data):
    """Unpack a collated batch and move its tensors to the GPU when enabled.

    Args:
        data: batch produced by ``MyDataset.collate_fn``. Index 0/1 hold the
            text inputs and their lengths, 2 the speaker names, 4/5 the mel
            spectrograms and their lengths, 8 the (possibly ``None``)
            attention mask.

    Returns:
        Tuple ``(text_input, text_lengths, mel_input, mel_lengths,
        speaker_ids, avg_text_length, avg_spec_length, attn_mask)``.
        ``mel_input`` is permuted to B x D x T; ``speaker_ids`` is ``None``
        unless ``c.use_speaker_embedding`` is set.
    """
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4].permute(0, 2, 1)  # B x D x T
    mel_lengths = data[5]
    attn_mask = data[8]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        # The mapping does not change during a run; load it only once.
        speaker_mapping = _SPEAKER_MAPPING_CACHE.get(OUT_PATH)
        if speaker_mapping is None:
            speaker_mapping = load_speaker_mapping(OUT_PATH)
            _SPEAKER_MAPPING_CACHE[OUT_PATH] = speaker_mapping
        speaker_ids = torch.LongTensor(
            [speaker_mapping[speaker_name] for speaker_name in speaker_names])
    else:
        speaker_ids = None

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
    return text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
        avg_text_length, avg_spec_length, attn_mask
def data_depended_init(model, ap):
    """Data depended initialization for normalization layers.

    Every decoder flow that exposes ``set_ddi`` is switched into DDI mode,
    one training batch is pushed through ``model.forward`` without gradient
    tracking, and DDI mode is switched off again.

    Args:
        model: Glow-TTS model, optionally wrapped so that the network lives
            under ``model.module`` (e.g. DDP).
        ap: AudioProcessor used to build the training loader.

    Returns:
        The same ``model`` object.
    """
    def toggle_ddi(enabled):
        # Reach through a DDP-style wrapper to the actual network.
        net = model.module if hasattr(model, 'module') else model
        for flow in net.decoder.flows:
            if getattr(flow, "set_ddi", False):
                flow.set_ddi(enabled)

    toggle_ddi(True)
    loader = setup_loader(ap, 1, is_val=False)
    model.train()
    print(" > Data depended initialization ... ")
    with torch.no_grad():
        # A single batch is enough; an empty loader simply skips the pass.
        first_batch = next(iter(loader), None)
        if first_batch is not None:
            (text_input, text_lengths, mel_input, mel_lengths, _speaker_ids,
             _avg_text_len, _avg_spec_len, attn_mask) = format_data(first_batch)
            model.forward(text_input, text_lengths, mel_input, mel_lengths,
                          attn_mask)
    toggle_ddi(False)
    return model
def train(model, criterion, optimizer, scheduler,
          ap, global_step, epoch, amp):
    """Run one training epoch over the training split.

    Args:
        model: Glow-TTS model, possibly wrapped so that the network lives
            under ``model.module`` (multi-GPU DDP).
        criterion: loss module; returns a dict whose ``'loss'`` entry is
            back-propagated (``'log_mle'`` / ``'loss_dur'`` are reduced
            across processes as well).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped every iteration when ``c.noam_schedule`` is set.
        ap: AudioProcessor used for loading, plotting and audio logging.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (0 enables verbose dataset output).
        amp: apex ``amp`` module when mixed precision is enabled, else None.

    Returns:
        ``(keep_avg.avg_values, global_step)``: the running averages of the
        logged values and the updated step counter.
    """
    data_loader = setup_loader(ap, 1, is_val=False,
                               verbose=(epoch == 0))
    model.train()
    # A DDP wrapper does not expose custom methods such as `inference`;
    # those must be called on the wrapped module.
    raw_model = model.module if hasattr(model, 'module') else model
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, _,\
            avg_text_length, avg_spec_length, attn_mask = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()

        # forward pass model
        z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
            text_input, text_lengths, mel_input, mel_lengths, attn_mask)

        # compute loss
        loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                              o_dur_log, o_total_dur, text_lengths)

        # backward pass
        if amp is not None:
            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_dict['loss'].backward()

        # With apex, gradient clipping must see the fp32 master params.
        amp_opt_params = amp.master_params(optimizer) if amp else None
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True,
                                    amp_opt_params=amp_opt_params)
        optimizer.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

        # detach loss values (tensors -> Python numbers)
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_dict.items()
        }

        # update avg stats
        update_train_values = {'avg_' + key: value for key, value in loss_dict.items()}
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
                                    model_loss=loss_dict['loss'],
                                    amp_state_dict=amp.state_dict() if amp else None)

                # Diagnostic visualizations
                # direct pass on model for spec predictions; this is only for
                # logging, so no autograd graph is needed.
                with torch.no_grad():
                    spec_pred, *_ = raw_model.inference(text_input[:1], text_lengths[:1])
                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Evaluate on the eval split and synthesize the test sentences.

    Runs without gradient tracking. Losses are accumulated in a
    ``KeepAverage``. On rank 0 the last eval batch is plotted and turned
    into audio, and once ``epoch >= c.test_delay_epochs`` every test
    sentence is synthesized, written to
    ``AUDIO_PATH/<global_step>/TestSentence_<idx>.wav`` and logged.

    Args:
        model: Glow-TTS model, possibly wrapped so that the network lives
            under ``model.module`` (multi-GPU DDP).
        criterion: loss module returning a dict of losses.
        ap: AudioProcessor used for loading, plotting and audio output.
        global_step: current global step (used for logging / file names).
        epoch: current epoch index.

    Returns:
        The ``KeepAverage`` running averages of the evaluation values.
    """
    data_loader = setup_loader(ap, 1, is_val=True)
    model.eval()
    # A DDP wrapper does not expose custom methods such as `inference`.
    raw_model = model.module if hasattr(model, 'module') else model
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        seen_batch = False
        for num_iter, data in enumerate(data_loader):
            seen_batch = True
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_input, mel_lengths, _,\
                _, _, attn_mask = format_data(data)

            # forward pass model
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input, text_lengths, mel_input, mel_lengths, attn_mask)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

            # detach loss values (tensors -> Python numbers)
            loss_dict = {
                key: value if isinstance(value, (int, float)) else value.item()
                for key, value in loss_dict.items()
            }

            # update avg stats
            keep_avg.update_values(
                {'avg_' + key: value for key, value in loss_dict.items()})

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # The diagnostics reuse the last batch; an eval split that produced no
        # batches would otherwise crash here with a NameError.
        if args.rank == 0 and seen_batch:
            # pylint: disable=undefined-loop-variable
            # Diagnostic visualizations
            # direct pass on model for spec predictions
            spec_pred, *_ = raw_model.inference(text_input[:1], text_lengths[:1])
            spec_pred = spec_pred.permute(0, 2, 1)
            gt_spec = mel_input.permute(0, 2, 1)

            const_spec = spec_pred[0].data.cpu().numpy()
            gt_spec = gt_spec[0].data.cpu().numpy()
            align_img = alignments[0].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec, ap),
                "ground_truth": plot_spectrogram(gt_spec, ap),
                "alignment": plot_alignment(align_img)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963."
            ]
        else:
            with open(c.test_sentences_file, "r", encoding="utf-8") as f:
                # Skip blank lines instead of synthesizing empty inputs.
                test_sentences = [s.strip() for s in f if s.strip()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    raw_model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:  # pylint: disable=broad-except
                # Best effort: one failing sentence must not stop training,
                # but KeyboardInterrupt is no longer swallowed here.
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model and optimizer, then run the train/eval epoch loop.

    Relies on the module-level globals created in the ``__main__`` section
    (``c``, ``OUT_PATH``, the loggers). It publishes ``meta_data_train`` /
    ``meta_data_eval`` (read by ``setup_loader``) and the possibly
    config-overridden ``symbols`` / ``phonemes`` as module globals.

    Args:
        args: parsed command-line namespace. ``restore_path``, ``rank`` and
            ``group_id`` are read; ``restore_step`` is written.
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes

    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # set the portion of the data used for training / evaluation
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
            assert all(speaker in speaker_mapping
                       for speaker in speakers), "As of now you, you cannot " \
                                                 "introduce new speakers to " \
                                                 "a previously trained model."
        else:
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
        # format_data() loads the mapping from OUT_PATH, so it has to be
        # written there for restored runs as well, not only fresh ones.
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0

    # setup model
    model = setup_model(num_chars, num_speakers, c)
    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
    criterion = GlowTTSLoss()

    # `DDP` is the module-level torch wrapper. Re-importing apex's wrapper as
    # `DDP` inside this function would make `DDP` a local name and raise
    # UnboundLocalError on multi-GPU runs without apex, so choose explicitly.
    ddp_wrapper = DDP
    if c.apex_amp_level:
        # pylint: disable=import-outside-toplevel
        from apex import amp
        from apex.parallel import DistributedDataParallel as ApexDDP
        ddp_wrapper = ApexDDP
        model.cuda()
        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
    else:
        amp = None

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:  # pylint: disable=broad-except
            # Strict restore failed (or reinit_layers was requested): merge
            # the checkpoint into the fresh state dict via set_init_dict.
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        if amp and 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

        for group in optimizer.param_groups:
            group['initial_lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = ddp_wrapper(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # No best loss is carried over from a checkpoint; start from infinity.
    best_loss = float('inf')

    global_step = args.restore_step
    # Data-dependent init re-derives the flow normalization parameters from a
    # data batch; on a restored model it would overwrite the loaded values,
    # so it only runs when training starts from step 0.
    if global_step == 0:
        model = data_depended_init(model, ap)
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 scheduler, ap, global_step,
                                                 epoch, amp)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict['avg_loss']
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
                                    OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
if __name__ == '__main__':

    def _str_to_bool(value):
        """Parse a command-line boolean (``bool('False')`` would be True)."""
        return str(value).strip().lower() in ('true', '1', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--continue_path',
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument(
        '--config_path',
        type=str,
        help='Path to config file for training.',
        required='--continue_path' not in sys.argv
    )
    # `type=bool` turned any given value (even "False") into True; this
    # parser keeps `--debug True` working and also accepts a bare `--debug`.
    parser.add_argument('--debug',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path:
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(os.path.join(args.continue_path, "*.pth.tar"))
        if not list_of_files:
            # max() on an empty list would only raise an opaque ValueError.
            raise FileNotFoundError(
                f" [!] No checkpoint (*.pth.tar) found in {args.continue_path}")
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    # check_config(c)
    if c.apex_amp_level:
        print(" > apex AMP level: ", c.apex_amp_level)

    OUT_PATH = args.continue_path
    if not args.continue_path:
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(args.config_path,
                         os.path.join(OUT_PATH, 'config.json'), new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

    LOG_DIR = OUT_PATH
    tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')

    # write model desc to tensorboard
    tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)