mirror of https://github.com/coqui-ai/TTS.git
Bug fixes for single-speaker Glow-TTS; enable torch-based AMP. Make AMP optional for WaveGrad. Bug fixes for the synthesis setup for Glow-TTS.
This commit is contained in:
parent
14c2381207
commit
946a0c0fb9
|
@ -15,8 +15,6 @@ from torch.utils.data import DataLoader
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||||
from TTS.tts.layers.losses import GlowTTSLoss
|
from TTS.tts.layers.losses import GlowTTSLoss
|
||||||
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
|
|
||||||
reduce_tensor)
|
|
||||||
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
|
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||||
|
@ -28,7 +26,8 @@ from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.console_logger import ConsoleLogger
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||||
create_experiment_folder, get_git_branch,
|
create_experiment_folder, get_git_branch,
|
||||||
remove_experiment_folder, set_init_dict)
|
remove_experiment_folder, set_init_dict,
|
||||||
|
set_amp_context)
|
||||||
from TTS.utils.io import copy_config_file, load_config
|
from TTS.utils.io import copy_config_file, load_config
|
||||||
from TTS.utils.radam import RAdam
|
from TTS.utils.radam import RAdam
|
||||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||||
|
@ -36,7 +35,6 @@ from TTS.utils.training import (NoamLR, check_update,
|
||||||
setup_torch_training_env)
|
setup_torch_training_env)
|
||||||
|
|
||||||
# DISTRIBUTED
|
# DISTRIBUTED
|
||||||
from apex.parallel import DistributedDataParallel as DDP_apex
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from TTS.utils.distribute import init_distributed, reduce_tensor
|
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||||
|
@ -157,7 +155,7 @@ def data_depended_init(model, ap, speaker_mapping=None):
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, optimizer, scheduler,
|
def train(model, criterion, optimizer, scheduler,
|
||||||
ap, global_step, epoch, amp, speaker_mapping=None):
|
ap, global_step, epoch, speaker_mapping=None):
|
||||||
data_loader = setup_loader(ap, 1, is_val=False,
|
data_loader = setup_loader(ap, 1, is_val=False,
|
||||||
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
|
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -170,6 +168,7 @@ def train(model, criterion, optimizer, scheduler,
|
||||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
c_logger.print_train_start()
|
c_logger.print_train_start()
|
||||||
|
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
@ -180,33 +179,38 @@ def train(model, criterion, optimizer, scheduler,
|
||||||
loader_time = time.time() - end_time
|
loader_time = time.time() - end_time
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# forward pass model
|
||||||
|
with set_amp_context(c.mixed_precision):
|
||||||
|
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||||
|
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
|
||||||
|
|
||||||
|
# compute loss
|
||||||
|
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||||
|
o_dur_log, o_total_dur, text_lengths)
|
||||||
|
|
||||||
|
# backward pass with loss scaling
|
||||||
|
if c.mixed_precision:
|
||||||
|
scaler.scale(loss_dict['loss']).backward()
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||||
|
c.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
loss_dict['loss'].backward()
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||||
|
c.grad_clip)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
# setup lr
|
# setup lr
|
||||||
if c.noam_schedule:
|
if c.noam_schedule:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# forward pass model
|
|
||||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
|
||||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
|
|
||||||
|
|
||||||
# compute loss
|
|
||||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
|
||||||
o_dur_log, o_total_dur, text_lengths)
|
|
||||||
|
|
||||||
# backward pass - DISTRIBUTED
|
|
||||||
if amp is not None:
|
|
||||||
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
|
|
||||||
scaled_loss.backward()
|
|
||||||
else:
|
|
||||||
loss_dict['loss'].backward()
|
|
||||||
|
|
||||||
if amp:
|
|
||||||
amp_opt_params = amp.master_params(optimizer)
|
|
||||||
else:
|
|
||||||
amp_opt_params = None
|
|
||||||
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# current_lr
|
# current_lr
|
||||||
current_lr = optimizer.param_groups[0]['lr']
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
|
@ -269,12 +273,12 @@ def train(model, criterion, optimizer, scheduler,
|
||||||
if c.checkpoint:
|
if c.checkpoint:
|
||||||
# save model
|
# save model
|
||||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||||
model_loss=loss_dict['loss'],
|
model_loss=loss_dict['loss'])
|
||||||
amp_state_dict=amp.state_dict() if amp else None)
|
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
# direct pass on model for spec predictions
|
# direct pass on model for spec predictions
|
||||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
|
target_speaker = None if speaker_ids is None else speaker_ids[:1]
|
||||||
|
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
||||||
spec_pred = spec_pred.permute(0, 2, 1)
|
spec_pred = spec_pred.permute(0, 2, 1)
|
||||||
gt_spec = mel_input.permute(0, 2, 1)
|
gt_spec = mel_input.permute(0, 2, 1)
|
||||||
const_spec = spec_pred[0].data.cpu().numpy()
|
const_spec = spec_pred[0].data.cpu().numpy()
|
||||||
|
@ -367,10 +371,11 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
# direct pass on model for spec predictions
|
# direct pass on model for spec predictions
|
||||||
|
target_speaker = None if speaker_ids is None else speaker_ids[:1]
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
|
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
||||||
else:
|
else:
|
||||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
|
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
||||||
spec_pred = spec_pred.permute(0, 2, 1)
|
spec_pred = spec_pred.permute(0, 2, 1)
|
||||||
gt_spec = mel_input.permute(0, 2, 1)
|
gt_spec = mel_input.permute(0, 2, 1)
|
||||||
|
|
||||||
|
@ -489,14 +494,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
|
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
|
||||||
criterion = GlowTTSLoss()
|
criterion = GlowTTSLoss()
|
||||||
|
|
||||||
if c.apex_amp_level is not None:
|
|
||||||
# pylint: disable=import-outside-toplevel
|
|
||||||
from apex import amp
|
|
||||||
model.cuda()
|
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
|
||||||
else:
|
|
||||||
amp = None
|
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||||
try:
|
try:
|
||||||
|
@ -513,9 +510,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
model.load_state_dict(model_dict)
|
model.load_state_dict(model_dict)
|
||||||
del model_dict
|
del model_dict
|
||||||
|
|
||||||
if amp and 'amp' in checkpoint:
|
|
||||||
amp.load_state_dict(checkpoint['amp'])
|
|
||||||
|
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
group['initial_lr'] = c.lr
|
group['initial_lr'] = c.lr
|
||||||
print(" > Model restored from step %d" % checkpoint['step'],
|
print(" > Model restored from step %d" % checkpoint['step'],
|
||||||
|
@ -530,10 +524,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
if c.apex_amp_level is not None:
|
model = DDP_th(model, device_ids=[args.rank])
|
||||||
model = DDP_apex(model)
|
|
||||||
else:
|
|
||||||
model = DDP_th(model, device_ids=[args.rank])
|
|
||||||
|
|
||||||
if c.noam_schedule:
|
if c.noam_schedule:
|
||||||
scheduler = NoamLR(optimizer,
|
scheduler = NoamLR(optimizer,
|
||||||
|
@ -554,14 +545,14 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
c_logger.print_epoch_start(epoch, c.epochs)
|
c_logger.print_epoch_start(epoch, c.epochs)
|
||||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
||||||
scheduler, ap, global_step,
|
scheduler, ap, global_step,
|
||||||
epoch, amp, speaker_mapping)
|
epoch, speaker_mapping)
|
||||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
|
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
|
||||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||||
target_loss = train_avg_loss_dict['avg_loss']
|
target_loss = train_avg_loss_dict['avg_loss']
|
||||||
if c.run_eval:
|
if c.run_eval:
|
||||||
target_loss = eval_avg_loss_dict['avg_loss']
|
target_loss = eval_avg_loss_dict['avg_loss']
|
||||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
||||||
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
|
OUT_PATH)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -614,8 +605,8 @@ if __name__ == '__main__':
|
||||||
check_config_tts(c)
|
check_config_tts(c)
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
if c.apex_amp_level:
|
if c.mixed_precision:
|
||||||
print(" > apex AMP level: ", c.apex_amp_level)
|
print(" > Mixed precision enabled.")
|
||||||
|
|
||||||
OUT_PATH = args.continue_path
|
OUT_PATH = args.continue_path
|
||||||
if args.continue_path == '':
|
if args.continue_path == '':
|
||||||
|
|
|
@ -16,7 +16,8 @@ from TTS.utils.console_logger import ConsoleLogger
|
||||||
from TTS.utils.distribute import init_distributed
|
from TTS.utils.distribute import init_distributed
|
||||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||||
create_experiment_folder, get_git_branch,
|
create_experiment_folder, get_git_branch,
|
||||||
remove_experiment_folder, set_init_dict)
|
remove_experiment_folder, set_init_dict,
|
||||||
|
set_amp_context)
|
||||||
from TTS.utils.io import copy_config_file, load_config
|
from TTS.utils.io import copy_config_file, load_config
|
||||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||||
from TTS.utils.training import setup_torch_training_env
|
from TTS.utils.training import setup_torch_training_env
|
||||||
|
@ -101,7 +102,7 @@ def train(model, criterion, optimizer,
|
||||||
model.compute_noise_level(noise_schedule['num_steps'],
|
model.compute_noise_level(noise_schedule['num_steps'],
|
||||||
noise_schedule['min_val'],
|
noise_schedule['min_val'],
|
||||||
noise_schedule['max_val'])
|
noise_schedule['max_val'])
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
@ -111,7 +112,7 @@ def train(model, criterion, optimizer,
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
with torch.cuda.amp.autocast():
|
with set_amp_context(c.mixed_precision):
|
||||||
# compute noisy input
|
# compute noisy input
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||||
|
@ -127,7 +128,7 @@ def train(model, criterion, optimizer,
|
||||||
|
|
||||||
# check nan loss
|
# check nan loss
|
||||||
if torch.isnan(loss).any():
|
if torch.isnan(loss).any():
|
||||||
raise RuntimeError(f'Detected NaN loss at step {self.step}.')
|
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,7 @@ class Encoder(nn.Module):
|
||||||
o = layer(o)
|
o = layer(o)
|
||||||
o = o.transpose(1, 2)
|
o = o.transpose(1, 2)
|
||||||
o = nn.utils.rnn.pack_padded_sequence(o,
|
o = nn.utils.rnn.pack_padded_sequence(o,
|
||||||
input_lengths,
|
input_lengths.cpu(),
|
||||||
batch_first=True)
|
batch_first=True)
|
||||||
self.lstm.flatten_parameters()
|
self.lstm.flatten_parameters()
|
||||||
o, _ = self.lstm(o)
|
o, _ = self.lstm(o)
|
||||||
|
|
|
@ -248,7 +248,7 @@ def check_config_tts(c):
|
||||||
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
|
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
|
||||||
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
|
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
|
||||||
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
if c['use_gst']:
|
if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
|
||||||
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
||||||
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
||||||
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
||||||
|
|
|
@ -210,7 +210,7 @@ def synthesis(model,
|
||||||
"""
|
"""
|
||||||
# GST processing
|
# GST processing
|
||||||
style_mel = None
|
style_mel = None
|
||||||
if CONFIG.use_gst and style_wav is not None:
|
if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
|
||||||
if isinstance(style_wav, dict):
|
if isinstance(style_wav, dict):
|
||||||
style_mel = style_wav
|
style_mel = style_wav
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,8 +1,19 @@
|
||||||
import os
|
|
||||||
import glob
|
|
||||||
import shutil
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def set_amp_context(mixed_precision):
|
||||||
|
if mixed_precision:
|
||||||
|
cm = torch.cuda.amp.autocast()
|
||||||
|
else:
|
||||||
|
cm = nullcontext()
|
||||||
|
return cm
|
||||||
|
|
||||||
|
|
||||||
def get_git_branch():
|
def get_git_branch():
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.utils import weight_norm
|
from torch.nn.utils import weight_norm
|
||||||
|
|
||||||
from math import log as ln
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(nn.Conv1d):
|
class Conv1d(nn.Conv1d):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue