diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index fcf6c4cd..81e7a9f2 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -7,41 +7,37 @@ import os
 import sys
 import time
 import traceback
+from random import randrange
 
 import torch
-from random import randrange
+# DISTRIBUTED
+from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
-
+from torch.utils.data.distributed import DistributedSampler
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
-from TTS.tts.utils.generic_utils import setup_model, check_config_tts
+from TTS.tts.utils.generic_utils import check_config_tts, setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping
+from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.distribute import init_distributed, reduce_tensor
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import (NoamLR, check_update,
-                                setup_torch_training_env)
-
-# DISTRIBUTED
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data.distributed import DistributedSampler
-from TTS.utils.distribute import init_distributed, reduce_tensor
-
+from TTS.utils.training import NoamLR, setup_torch_training_env
 
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
-def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
+def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -78,29 +74,29 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
 
 
 def format_data(data):
-    if c.use_speaker_embedding:
-        speaker_mapping = load_speaker_mapping(OUT_PATH)
-
     # setup input data
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
     mel_input = data[4].permute(0, 2, 1)  # B x D x T
     mel_lengths = data[5]
-    attn_mask = data[8]
+    item_idx = data[7]
+    attn_mask = data[9]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
 
     if c.use_speaker_embedding:
         if c.use_external_speaker_embedding_file:
-            speaker_ids = data[8]
+            # return precomputed embedding vector
+            speaker_c = data[8]
         else:
-            speaker_ids = [
+            # return speaker_id to be used by an embedding layer
+            speaker_c = [
                 speaker_mapping[speaker_name] for speaker_name in speaker_names
             ]
-            speaker_ids = torch.LongTensor(speaker_ids)
+            speaker_c = torch.LongTensor(speaker_c)
     else:
-        speaker_ids = None
+        speaker_c = None
 
     # dispatch data to GPU
     if use_cuda:
@@ -108,15 +104,15 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        if speaker_ids is not None:
-            speaker_ids = speaker_ids.cuda(non_blocking=True)
+        if speaker_c is not None:
+            speaker_c = speaker_c.cuda(non_blocking=True)
         if attn_mask is not None:
             attn_mask = attn_mask.cuda(non_blocking=True)
-    return text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-        avg_text_length, avg_spec_length, attn_mask
+    return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
+        avg_text_length, avg_spec_length, attn_mask, item_idx
 
 
-def data_depended_init(model, ap, speaker_mapping=None):
+def data_depended_init(model, ap):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -127,20 +123,23 @@
             if getattr(f, "set_ddi", False):
                 f.set_ddi(True)
 
-    data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
+    data_loader = setup_loader(ap, 1, is_val=False)
     model.train()
     print(" > Data depended initialization ... ")
+    num_iter = 0
     with torch.no_grad():
         for _, data in enumerate(data_loader):
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                _, _, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, speaker_embed,\
+                _, _, attn_mask, item_idx = format_data(data)
 
             # forward pass model
             _ = model.forward(
-                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
-            break
+                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_embed)
+            if num_iter == c.data_dep_init_iter:
+                break
+            num_iter += 1
 
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -154,9 +153,9 @@
 
 
 def train(model, criterion, optimizer, scheduler,
-          ap, global_step, epoch, speaker_mapping=None):
+          ap, global_step, epoch):
     data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
+                               verbose=(epoch == 0))
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -172,8 +171,8 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()
 
         # format data
-        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-            avg_text_length, avg_spec_length, attn_mask = format_data(data)
+        text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
+            avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)
 
         loader_time = time.time() - end_time
 
@@ -203,10 +202,6 @@
         grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                    c.grad_clip)
         optimizer.step()
-
-        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
-        optimizer.step()
-
         # setup lr
         if c.noam_schedule:
             scheduler.step()
@@ -215,7 +210,7 @@
         current_lr = optimizer.param_groups[0]['lr']
 
         # compute alignment error (the lower the better )
-        align_error = 1 - alignment_diagonal_score(alignments)
+        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict['align_error'] = align_error
 
         step_time = time.time() - start_time
@@ -276,7 +271,7 @@ def train(model, criterion, optimizer, scheduler,
 
             # Diagnostic visualizations
             # direct pass on model for spec predictions
-            target_speaker = None if speaker_ids is None else speaker_ids[:1]
+            target_speaker = None if speaker_c is None else speaker_c[:1]
             spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
             spec_pred = spec_pred.permute(0, 2, 1)
             gt_spec = mel_input.permute(0, 2, 1)
@@ -313,8 +308,8 @@
 
 
 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
-    data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping)
+def evaluate(model, criterion, ap, global_step, epoch):
+    data_loader = setup_loader(ap, 1, is_val=True)
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -324,12 +319,12 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
             start_time = time.time()
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                _, _, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
+                _, _, attn_mask, item_idx = format_data(data)
 
             # forward pass model
             z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
+                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
 
             # compute loss
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -370,7 +365,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
     if args.rank == 0:
         # Diagnostic visualizations
         # direct pass on model for spec predictions
-        target_speaker = None if speaker_ids is None else speaker_ids[:1]
+        target_speaker = None if speaker_c is None else speaker_c[:1]
         if hasattr(model, 'module'):
             spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
         else:
@@ -464,7 +459,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
 # FIXME: move args definition/parsing inside of main?
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes
+    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -539,13 +534,13 @@ def main(args):  # pylint: disable=redefined-outer-name
 
     best_loss = float('inf')
     global_step = args.restore_step
-    model = data_depended_init(model, ap, speaker_mapping)
+    model = data_depended_init(model, ap)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  scheduler, ap, global_step,
-                                                 epoch, speaker_mapping)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
+                                                 epoch)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index 1263a616..b3fbc415 100644
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -18,7 +18,7 @@ from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.utils.generic_utils import check_config_tts, setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import load_speaker_mapping, parse_speakers
+from TTS.tts.utils.speakers import parse_speakers
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -39,7 +39,7 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
+def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -74,10 +74,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
             pin_memory=False)
     return loader
 
-def format_data(data, speaker_mapping=None):
-    if speaker_mapping is None and c.use_speaker_embedding and not c.use_external_speaker_embedding_file:
-        speaker_mapping = load_speaker_mapping(OUT_PATH)
-
+def format_data(data):
     # setup input data
     text_input = data[0]
     text_lengths = data[1]
@@ -127,7 +124,7 @@ def format_data(data):
 
 
 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch, scaler, scaler_st, speaker_mapping=None):
+          ap, global_step, epoch, scaler, scaler_st):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
-                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
+                               verbose=(epoch == 0))
     model.train()
@@ -144,7 +141,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         start_time = time.time()
 
         # format data
-        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data, speaker_mapping)
+        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data)
         loader_time = time.time() - end_time
 
         global_step += 1
@@ -327,7 +324,7 @@
 
 
 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
+def evaluate(model, criterion, ap, global_step, epoch):
+    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
     model.eval()
     epoch_time = 0
@@ -338,7 +335,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
             start_time = time.time()
 
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data, speaker_mapping)
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data)
             assert mel_input.shape[1] % model.decoder.r == 0
 
             # forward pass model
@@ -493,7 +490,7 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
 # FIXME: move args definition/parsing inside of main?
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes
+    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -599,8 +596,8 @@ def main(args):  # pylint: disable=redefined-outer-name
             print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, scaler, scaler_st, speaker_mapping)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
+                                                 global_step, epoch, scaler, scaler_st)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index dec8243a..734c11e5 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -104,6 +104,7 @@ class GlowTts(nn.Module):
             c_in_channels=self.c_in_channels)
 
         if num_speakers > 1 and not external_speaker_embedding_dim:
+            # speaker embedding layer
             self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
 
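Note on the speaker-conditioning contract introduced above (a minimal sketch, not part of the patch): `format_data` now returns a single `speaker_c` that is either a `LongTensor` of speaker ids, later mapped to vectors by the model's `emb_g` embedding layer, or a precomputed per-sample embedding vector when `c.use_external_speaker_embedding_file` is set. The sketch below illustrates the two modes with made-up shapes; the flag `use_external_file` is a hypothetical stand-in for the config values used in the diff.

```python
import torch
import torch.nn as nn

# Stand-in values; in the patch these come from the training config
# (c.use_speaker_embedding / c.use_external_speaker_embedding_file).
num_speakers, embed_dim, batch = 4, 512, 2
use_external_file = False  # hypothetical flag for illustration

if use_external_file:
    # Mode 1: precomputed speaker embeddings (e.g. from a speaker encoder)
    # are passed through as float vectors, one per sample.
    speaker_c = torch.randn(batch, embed_dim)
else:
    # Mode 2: integer speaker ids are mapped to vectors by an embedding
    # layer, mirroring GlowTts.emb_g in TTS/tts/models/glow_tts.py.
    emb_g = nn.Embedding(num_speakers, embed_dim)
    nn.init.uniform_(emb_g.weight, -0.1, 0.1)
    speaker_ids = torch.LongTensor([0, 3])  # one id per sample
    speaker_c = emb_g(speaker_ids)          # shape: (batch, embed_dim)

print(speaker_c.shape)  # torch.Size([2, 512])
```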