updated to current dev

Author: gerazov
Date:   2021-02-06 22:59:52 +01:00
Commit: 8fdd08ea15 (parent: 2705d27b28)

6 changed files with 109 additions and 107 deletions

File 1 of 6

@@ -116,7 +116,7 @@ def format_data(data):
 avg_text_length, avg_spec_length, attn_mask, item_idx
-def data_depended_init(data_loader, model, ap):
+def data_depended_init(data_loader, model):
 """Data depended initialization for activation normalization."""
 if hasattr(model, 'module'):
 for f in model.module.decoder.flows:
@@ -135,7 +135,7 @@ def data_depended_init(data_loader, model, ap):
 # format data
 text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
-_, _, attn_mask, item_idx = format_data(data)
+_, _, attn_mask, _ = format_data(data)
 # forward pass model
 _ = model.forward(
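The two hunks above drop arguments that data_depended_init never uses: the audio processor and the per-item indices returned by format_data. The routine itself just runs a few batches through the model so the flow's activation-normalization layers can initialize their scale and bias from real activation statistics. A minimal sketch of that pattern, with a hypothetical norm_layers/initialized interface rather than the project's actual API:

    import torch

    def data_dependent_init(data_loader, model, num_batches=10):
        """Run a few forward passes so ActNorm-style layers can set scale/bias from data."""
        model.train()
        for layer in getattr(model, "norm_layers", []):   # hypothetical attribute
            layer.initialized = False                     # re-arm data-dependent init
        with torch.no_grad():
            for i, batch in enumerate(data_loader):
                if i >= num_batches:
                    break
                model(*batch)                             # stats are captured inside the layers
        for layer in getattr(model, "norm_layers", []):
            layer.initialized = True                      # freeze the initialized parameters
        return model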
@@ -174,7 +174,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
 # format data
 text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
-avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)
+avg_text_length, avg_spec_length, attn_mask, _ = format_data(data)
 loader_time = time.time() - end_time
@@ -188,20 +188,20 @@ def train(data_loader, model, criterion, optimizer, scheduler,
 # compute loss
 loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
 o_dur_log, o_total_dur, text_lengths)
 # backward pass with loss scaling
 if c.mixed_precision:
 scaler.scale(loss_dict['loss']).backward()
 scaler.unscale_(optimizer)
 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
 c.grad_clip)
 scaler.step(optimizer)
 scaler.update()
 else:
 loss_dict['loss'].backward()
 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
 c.grad_clip)
 optimizer.step()
 # setup lr
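The mixed-precision branch above is the standard torch.cuda.amp recipe: backward on the scaled loss, unscale before clipping so c.grad_clip applies to true gradient magnitudes, then let the scaler drive the optimizer step (it skips the step on overflow). A self-contained sketch of the same pattern, with a toy model and an illustrative clip value in place of the script's config:

    import torch

    model = torch.nn.Linear(80, 80).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()
    grad_clip = 5.0                                   # illustrative, plays the role of c.grad_clip

    for _ in range(100):
        x = torch.randn(16, 80, device="cuda")
        y = torch.randn(16, 80, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()                 # backward on the scaled loss
        scaler.unscale_(optimizer)                    # unscale so clipping sees real gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)                        # skipped internally if grads overflowed
        scaler.update()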
@@ -329,7 +329,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
 # format data
 text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
-_, _, attn_mask, item_idx = format_data(data)
+_, _, attn_mask, _ = format_data(data)
 # forward pass model
 z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
@@ -546,13 +546,14 @@ def main(args): # pylint: disable=redefined-outer-name
 eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
 global_step = args.restore_step
-model = data_depended_init(train_loader, model, ap)
+model = data_depended_init(train_loader, model)
 for epoch in range(0, c.epochs):
 c_logger.print_epoch_start(epoch, c.epochs)
 train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
 scheduler, ap, global_step,
 epoch)
-eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
+eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap,
+global_step, epoch)
 c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
 target_loss = train_avg_loss_dict['avg_loss']
 if c.run_eval:

File 2 of 6

@@ -172,13 +172,13 @@ def train(data_loader, model, criterion, optimizer, scheduler,
 scaler.scale(loss_dict['loss']).backward()
 scaler.unscale_(optimizer)
 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
 c.grad_clip)
 scaler.step(optimizer)
 scaler.update()
 else:
 loss_dict['loss'].backward()
 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
 c.grad_clip)
 optimizer.step()
 # setup lr
@@ -515,12 +515,14 @@ def main(args): # pylint: disable=redefined-outer-name
 train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
 scheduler, ap, global_step,
 epoch)
-eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
+eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap,
+global_step, epoch)
 c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
 target_loss = train_avg_loss_dict['avg_loss']
 if c.run_eval:
 target_loss = eval_avg_loss_dict['avg_loss']
-best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
+best_loss = save_best_model(target_loss, best_loss, model, optimizer,
+global_step, epoch, c.r,
 OUT_PATH)

File 3 of 6

@@ -9,8 +9,8 @@ from random import randrange
 import numpy as np
 import torch
-from TTS.utils.arguments import parse_arguments, process_args
 from torch.utils.data import DataLoader
+from TTS.utils.arguments import parse_arguments, process_args
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import TacotronLoss
@@ -62,7 +62,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
 c.use_external_speaker_embedding_file
 ) else None
 )
 )
 if c.use_phonemes and c.compute_input_seq_cache:
 # precompute phonemes to have a better estimate of sequence lengths.
@@ -179,10 +179,10 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
 # compute loss
 loss_dict = criterion(postnet_output, decoder_output, mel_input,
 linear_input, stop_tokens, stop_targets,
 mel_lengths, decoder_backward_output,
-alignments, alignment_lengths, alignments_backward,
-text_lengths)
+alignments, alignment_lengths,
+alignments_backward, text_lengths)
 # check nan loss
 if torch.isnan(loss_dict['loss']).any():
@@ -200,7 +200,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
 # stopnet optimizer step
 if c.separate_stopnet:
-scaler_st.scale( loss_dict['stopnet_loss']).backward()
+scaler_st.scale(loss_dict['stopnet_loss']).backward()
 scaler.unscale_(optimizer_st)
 optimizer_st, _ = adam_weight_decay(optimizer_st)
 grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
@@ -534,8 +534,7 @@ def main(args): # pylint: disable=redefined-outer-name
 optimizer_st = None
 # setup criterion
-criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
+criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
 if args.restore_path:
 checkpoint = torch.load(args.restore_path, map_location='cpu')
 try:
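Replacing the hard-coded 10.0 with c.stopnet_pos_weight makes the stop-token class weighting configurable instead of fixed. The underlying mechanism is PyTorch's pos_weight, which up-weights the rare positive (stop) frames against the many negative (keep decoding) frames; a small illustration with made-up shapes and values:

    import torch

    stopnet_pos_weight = 10.0                       # illustrative; the commit reads this from the config
    criterion_st = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor(stopnet_pos_weight))

    logits = torch.randn(4, 100)                    # stopnet output, one logit per decoder frame
    targets = torch.zeros(4, 100)
    targets[:, -5:] = 1.0                           # only the last few frames are "stop"
    loss = criterion_st(logits, targets)            # positive frames count 10x in the loss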
@@ -637,7 +636,8 @@ def main(args): # pylint: disable=redefined-outer-name
 epoch,
 c.r,
 OUT_PATH,
-scaler=scaler.state_dict() if c.mixed_precision else None)
+scaler=scaler.state_dict() if c.mixed_precision else None
+)
 if __name__ == '__main__':

File 4 of 6

@@ -8,8 +8,8 @@ import traceback
 from inspect import signature
 import torch
-from TTS.utils.arguments import parse_arguments, process_args
 from torch.utils.data import DataLoader
+from TTS.utils.arguments import parse_arguments, process_args
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
 remove_experiment_folder, set_init_dict)
@@ -33,9 +33,8 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)
 def setup_loader(ap, is_val=False, verbose=False):
-if is_val and not c.run_eval:
-loader = None
-else:
+loader = None
+if not is_val or c.run_eval:
 dataset = GANDataset(ap=ap,
 items=eval_data if is_val else train_data,
 seq_len=c.seq_len,
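The setup_loader rewrite above removes one branch: instead of special-casing `is_val and not c.run_eval`, the loader defaults to None and the dataset is built only when `not is_val or c.run_eval` holds, which is the De Morgan negation of the old guard. A tiny check of the equivalence, with strings standing in for the real DataLoader:

    def setup_loader_old(is_val, run_eval):
        if is_val and not run_eval:
            loader = None
        else:
            loader = "DataLoader(...)"        # stands in for the real construction
        return loader

    def setup_loader_new(is_val, run_eval):
        loader = None
        if not is_val or run_eval:            # == not (is_val and not run_eval)
            loader = "DataLoader(...)"
        return loader

    # the two variants agree on every input combination
    assert all(setup_loader_old(v, e) == setup_loader_new(v, e)
               for v in (True, False) for e in (True, False))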
@@ -114,7 +113,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
 y_hat = model_G(c_G)
 y_hat_sub = None
 y_G_sub = None
-y_hat_vis = y_hat # for visualization # FIXME! .clone().detach()
+y_hat_vis = y_hat # for visualization
 # PQMF formatting
 if y_hat.shape[1] > 1:
@@ -274,14 +273,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
 # compute spectrograms
 figures = plot_results(y_hat_vis, y_G, ap, global_step,
 'train')
 tb_logger.tb_train_figures(global_step, figures)
 # Sample audio
 sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
 tb_logger.tb_train_audios(global_step,
 {'train/audio': sample_voice},
 c.audio["sample_rate"])
 end_time = time.time()
 # print epoch stats

File 5 of 6

@@ -8,12 +8,12 @@ import traceback
 import numpy as np
 import torch
-from TTS.utils.arguments import parse_arguments, process_args
 # DISTRIBUTED
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from TTS.utils.arguments import parse_arguments, process_args
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
@@ -32,16 +32,16 @@ def setup_loader(ap, is_val=False, verbose=False):
 loader = None
 else:
 dataset = WaveGradDataset(ap=ap,
 items=eval_data if is_val else train_data,
 seq_len=c.seq_len,
 hop_len=ap.hop_length,
 pad_short=c.pad_short,
 conv_pad=c.conv_pad,
 is_training=not is_val,
 return_segments=True,
 use_noise_augment=False,
 use_cache=c.use_cache,
 verbose=verbose)
 sampler = DistributedSampler(dataset) if num_gpus > 1 else None
 loader = DataLoader(dataset,
 batch_size=c.batch_size,
@@ -77,8 +77,8 @@ def format_test_data(data):
 return m, x
-def train(model, criterion, optimizer,
-scheduler, scaler, ap, global_step, epoch):
+def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
+epoch):
 data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
 model.train()
 epoch_time = 0
@@ -92,7 +92,8 @@ def train(model, criterion, optimizer,
 c_logger.print_train_start()
 # setup noise schedule
 noise_schedule = c['train_noise_schedule']
-betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
+betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
+noise_schedule['num_steps'])
 if hasattr(model, 'module'):
 model.module.compute_noise_level(betas)
 else:
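The wrapped np.linspace call only reflows the line; it still builds the linear beta schedule that compute_noise_level turns into per-step noise levels. Under the usual diffusion parameterization that amounts to the sketch below (values are illustrative, and the model's own method may differ in detail):

    import numpy as np

    noise_schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}   # illustrative values
    betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"],
                        noise_schedule["num_steps"])

    alphas = 1.0 - betas                        # per-step signal retention
    alphas_cumprod = np.cumprod(alphas)         # cumulative product over diffusion steps
    noise_level = np.sqrt(alphas_cumprod)       # sqrt(alpha_bar): the noise level conditioning the model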
@@ -118,7 +119,7 @@ def train(model, criterion, optimizer,
 # compute losses
 loss = criterion(noise, noise_hat)
-loss_wavegrad_dict = {'wavegrad_loss':loss}
+loss_wavegrad_dict = {'wavegrad_loss': loss}
 # check nan loss
 if torch.isnan(loss).any():
@@ -131,13 +132,13 @@ def train(model, criterion, optimizer,
 scaler.scale(loss).backward()
 scaler.unscale_(optimizer)
 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
 c.clip_grad)
 scaler.step(optimizer)
 scaler.update()
 else:
 loss.backward()
 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
 c.clip_grad)
 optimizer.step()
 # schedule update
@@ -193,17 +194,19 @@ def train(model, criterion, optimizer,
 if global_step % c.save_step == 0:
 if c.checkpoint:
 # save model
-save_checkpoint(model,
-optimizer,
-scheduler,
-None,
-None,
-None,
-global_step,
-epoch,
-OUT_PATH,
-model_losses=loss_dict,
-scaler=scaler.state_dict() if c.mixed_precision else None)
+save_checkpoint(
+model,
+optimizer,
+scheduler,
+None,
+None,
+None,
+global_step,
+epoch,
+OUT_PATH,
+model_losses=loss_dict,
+scaler=scaler.state_dict() if c.mixed_precision else None
+)
 end_time = time.time()
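The save_checkpoint call above is only reflowed (one argument per line under the opening parenthesis); its positional None slots and keyword arguments are unchanged. In generic PyTorch terms, checkpointing a resumable training state looks roughly like the hypothetical helper below, which is not the project's save_checkpoint:

    import torch

    def save_training_state(path, model, optimizer, scheduler, step, epoch,
                            losses=None, scaler_state=None):
        """Bundle everything needed to resume training into a single file."""
        state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler is not None else None,
            "step": step,
            "epoch": epoch,
            "losses": losses,
            "scaler": scaler_state,        # e.g. scaler.state_dict() when mixed precision is on
        }
        torch.save(state, path)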
@@ -250,7 +253,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 # compute losses
 loss = criterion(noise, noise_hat)
-loss_wavegrad_dict = {'wavegrad_loss':loss}
+loss_wavegrad_dict = {'wavegrad_loss': loss}
 loss_dict = dict()
@@ -282,7 +285,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
 # setup noise schedule and inference
 noise_schedule = c['test_noise_schedule']
-betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
+betas = np.linspace(noise_schedule['min_val'],
+noise_schedule['max_val'],
+noise_schedule['num_steps'])
 if hasattr(model, 'module'):
 model.module.compute_noise_level(betas)
 # compute voice
@@ -313,7 +318,8 @@ def main(args): # pylint: disable=redefined-outer-name
 print(f" > Loading wavs from: {c.data_path}")
 if c.feature_path is not None:
 print(f" > Loading features from: {c.feature_path}")
-eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
+eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
+c.eval_split_size)
 else:
 eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
@@ -343,6 +349,10 @@ def main(args): # pylint: disable=redefined-outer-name
 # setup criterion
 criterion = torch.nn.L1Loss().cuda()
+if use_cuda:
+model.cuda()
+criterion.cuda()
 if args.restore_path:
 checkpoint = torch.load(args.restore_path, map_location='cpu')
 try:
@@ -376,10 +386,6 @@ def main(args): # pylint: disable=redefined-outer-name
 else:
 args.restore_step = 0
-if use_cuda:
-model.cuda()
-criterion.cuda()
 # DISTRUBUTED
 if num_gpus > 1:
 model = DDP_th(model, device_ids=[args.rank])
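The last two hunks move the use_cuda block from after the checkpoint-restore branch to before it, so the model and criterion are already on the GPU when a checkpoint and optimizer state are restored and before the DDP wrapper is applied. A sketch of that ordering with placeholder objects (the real script also restores optimizer and scaler state from the loaded checkpoint):

    import torch

    model = torch.nn.Linear(80, 80)
    criterion = torch.nn.L1Loss()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()                       # move to GPU first ...
        criterion.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    restore_path = ""                      # placeholder for args.restore_path
    if restore_path:
        checkpoint = torch.load(restore_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])   # ... so restored state matches GPU parameters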
@@ -393,26 +399,26 @@ def main(args): # pylint: disable=redefined-outer-name
 global_step = args.restore_step
 for epoch in range(0, c.epochs):
 c_logger.print_epoch_start(epoch, c.epochs)
-_, global_step = train(model, criterion, optimizer,
-scheduler, scaler, ap, global_step,
-epoch)
-eval_avg_loss_dict = evaluate(model, criterion, ap,
-global_step, epoch)
+_, global_step = train(model, criterion, optimizer, scheduler, scaler,
+ap, global_step, epoch)
+eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
 c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
 target_loss = eval_avg_loss_dict[c.target_loss]
-best_loss = save_best_model(target_loss,
-best_loss,
-model,
-optimizer,
-scheduler,
-None,
-None,
-None,
-global_step,
-epoch,
-OUT_PATH,
-model_losses=eval_avg_loss_dict,
-scaler=scaler.state_dict() if c.mixed_precision else None)
+best_loss = save_best_model(
+target_loss,
+best_loss,
+model,
+optimizer,
+scheduler,
+None,
+None,
+None,
+global_step,
+epoch,
+OUT_PATH,
+model_losses=eval_avg_loss_dict,
+scaler=scaler.state_dict() if c.mixed_precision else None
+)
 if __name__ == '__main__':

File 6 of 6

@@ -178,18 +178,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
 if global_step % c.save_step == 0:
 if c.checkpoint:
 # save model
-save_checkpoint(model,
-optimizer,
-scheduler,
-None,
-None,
-None,
-global_step,
-epoch,
-OUT_PATH,
-model_losses=loss_dict,
-scaler=scaler.state_dict() if c.mixed_precision else None
-)
+save_checkpoint(
+model,
+optimizer,
+scheduler,
+None,
+None,
+None,
+global_step,
+epoch,
+OUT_PATH,
+model_losses=loss_dict,
+scaler=scaler.state_dict() if c.mixed_precision else None
+)
 # synthesize a full voice
 rand_idx = random.randrange(0, len(train_data))
@@ -204,14 +205,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
 c.batched,
 c.target_samples,
 c.overlap_samples,
-# use_cuda
 )
-# sample_wav = model.generate(ground_mel,
-# c.batched,
-# c.target_samples,
-# c.overlap_samples,
-# use_cuda
-# )
 predict_mel = ap.melspectrogram(sample_wav)
 # compute spectrograms
@@ -300,7 +294,6 @@ def evaluate(model, criterion, ap, global_step, epoch):
 c.batched,
 c.target_samples,
 c.overlap_samples,
-# use_cuda
 )
 predict_mel = ap.melspectrogram(sample_wav)
@@ -311,9 +304,10 @@ def evaluate(model, criterion, ap, global_step, epoch):
 )
 # compute spectrograms
-figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
-"eval/prediction": plot_spectrogram(predict_mel.T)
-}
+figures = {
+"eval/ground_truth": plot_spectrogram(ground_mel.T),
+"eval/prediction": plot_spectrogram(predict_mel.T)
+}
 tb_logger.tb_eval_figures(global_step, figures)
 tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)