diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index c5e570e5..d924b906 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -9,6 +9,7 @@ import time
 import traceback
 
 import torch
+from random import randrange
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
@@ -36,8 +37,7 @@ from TTS.utils.training import (NoamLR, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)
 
 
-def setup_loader(ap, r, is_val=False, verbose=False):
-
+def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -56,7 +56,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             use_phonemes=c.use_phonemes,
             phoneme_language=c.phoneme_language,
             enable_eos_bos=c.enable_eos_bos_chars,
-            verbose=verbose)
+            verbose=verbose,
+            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
@@ -86,10 +87,13 @@ def format_data(data):
     avg_spec_length = torch.mean(mel_lengths.float())
 
     if c.use_speaker_embedding:
-        speaker_ids = [
-            speaker_mapping[speaker_name] for speaker_name in speaker_names
-        ]
-        speaker_ids = torch.LongTensor(speaker_ids)
+        if c.use_external_speaker_embedding_file:
+            speaker_ids = data[8]
+        else:
+            speaker_ids = [
+                speaker_mapping[speaker_name] for speaker_name in speaker_names
+            ]
+            speaker_ids = torch.LongTensor(speaker_ids)
     else:
         speaker_ids = None
 
@@ -107,7 +111,7 @@ def format_data(data):
            avg_text_length, avg_spec_length, attn_mask
 
 
-def data_depended_init(model, ap):
+def data_depended_init(model, ap, speaker_mapping=None):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -118,19 +122,19 @@ def data_depended_init(model, ap):
             if getattr(f, "set_ddi", False):
                 f.set_ddi(True)
 
-    data_loader = setup_loader(ap, 1, is_val=False)
+    data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
     model.train()
     print(" > Data depended initialization ... ")
") with torch.no_grad(): for _, data in enumerate(data_loader): # format data - text_input, text_lengths, mel_input, mel_lengths, _,\ + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ _, _, attn_mask = format_data(data) # forward pass model _ = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) break if hasattr(model, 'module'): @@ -145,9 +149,9 @@ def data_depended_init(model, ap): def train(model, criterion, optimizer, scheduler, - ap, global_step, epoch, amp): + ap, global_step, epoch, amp, speaker_mapping=None): data_loader = setup_loader(ap, 1, is_val=False, - verbose=(epoch == 0)) + verbose=(epoch == 0), speaker_mapping=speaker_mapping) model.train() epoch_time = 0 keep_avg = KeepAverage() @@ -162,7 +166,7 @@ def train(model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, _,\ + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ avg_text_length, avg_spec_length, attn_mask = format_data(data) loader_time = time.time() - end_time @@ -176,7 +180,7 @@ def train(model, criterion, optimizer, scheduler, # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) # compute loss loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, @@ -262,7 +266,7 @@ def train(model, criterion, optimizer, scheduler, # Diagnostic visualizations # direct pass on model for spec predictions - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1]) + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) spec_pred = spec_pred.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1) const_spec = spec_pred[0].data.cpu().numpy() @@ -298,8 +302,8 @@ def train(model, criterion, optimizer, scheduler, @torch.no_grad() -def evaluate(model, criterion, ap, global_step, epoch): - data_loader = setup_loader(ap, 1, is_val=True) +def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping): + data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping) model.eval() epoch_time = 0 keep_avg = KeepAverage() @@ -309,12 +313,12 @@ def evaluate(model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, _,\ + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ _, _, attn_mask = format_data(data) # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) # compute loss loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, @@ -356,9 +360,9 @@ def evaluate(model, criterion, ap, global_step, epoch): # Diagnostic visualizations # direct pass on model for spec predictions if hasattr(model, 'module'): - spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1]) + spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) else: - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1]) + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) spec_pred = spec_pred.permute(0, 2, 1) gt_spec = 
@@ -398,7 +402,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
-        speaker_id = 0 if c.use_speaker_embedding else None
+        if c.use_speaker_embedding:
+            if c.use_external_speaker_embedding_file:
+                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]]['embedding']
+                speaker_id = None
+            else:
+                speaker_id = 0
+                speaker_embedding = None
+        else:
+            speaker_id = None
+            speaker_embedding = None
+
         style_wav = c.get("style_wav_for_test")
         for idx, test_sentence in enumerate(test_sentences):
             try:
@@ -409,6 +423,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     use_cuda,
                     ap,
                     speaker_id=speaker_id,
+                    speaker_embedding=speaker_embedding,
                     style_wav=style_wav,
                     truncated=False,
                     enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
@@ -462,23 +477,42 @@ def main(args):  # pylint: disable=redefined-outer-name
     if c.use_speaker_embedding:
         speakers = get_speakers(meta_data_train)
         if args.restore_path:
-            prev_out_path = os.path.dirname(args.restore_path)
-            speaker_mapping = load_speaker_mapping(prev_out_path)
-            assert all([speaker in speaker_mapping
-                        for speaker in speakers]), "As of now you, you cannot " \
-                                                   "introduce new speakers to " \
-                                                   "a previously trained model."
-        else:
+            if c.use_external_speaker_embedding_file:  # restoring a checkpoint and using an external embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                if not speaker_mapping:
+                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
+                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+                    if not speaker_mapping:
+                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
+                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+            elif not c.use_external_speaker_embedding_file:  # restoring a checkpoint without an external embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                speaker_embedding_dim = None
+                assert all([speaker in speaker_mapping
+                            for speaker in speakers]), "As of now, you cannot " \
+                                                       "introduce new speakers to " \
+                                                       "a previously trained model."
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:  # starting a new training run with an external embedding file
+            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:  # external embeddings requested but no file given
+            raise RuntimeError("use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file. Run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder to generate one.")
+        else:  # starting a new training run without an external embedding file
             speaker_mapping = {name: i for i, name in enumerate(speakers)}
+            speaker_embedding_dim = None
         save_speaker_mapping(OUT_PATH, speaker_mapping)
         num_speakers = len(speaker_mapping)
-        print("Training with {} speakers: {}".format(num_speakers,
+        print("Training with {} speakers: {}".format(len(speakers),
                                                      ", ".join(speakers)))
     else:
         num_speakers = 0
+        speaker_embedding_dim = None
+        speaker_mapping = None
 
     # setup model
-    model = setup_model(num_chars, num_speakers, c)
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
     optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
     criterion = GlowTTSLoss()
@@ -540,13 +574,13 @@ def main(args):  # pylint: disable=redefined-outer-name
 
     best_loss = float('inf')
     global_step = args.restore_step
-    model = data_depended_init(model, ap)
+    model = data_depended_init(model, ap, speaker_mapping)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  scheduler, ap, global_step,
-                                                 epoch, amp)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+                                                 epoch, amp, speaker_mapping)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index a9b6f8c0..dec8243a 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -68,13 +68,14 @@ class GlowTts(nn.Module):
         self.use_encoder_prenet = use_encoder_prenet
         self.noise_scale = 0.66
         self.length_scale = 1.
+        self.external_speaker_embedding_dim = external_speaker_embedding_dim
 
-        # if is a multispeaker and c_in_channels is 0, set to 256
+        # if the model is multi-speaker and c_in_channels is 0, default it to 512
         if num_speakers > 1:
-            if self.c_in_channels == 0 and not external_speaker_embedding_dim:
-                self.c_in_channels = 256
-            elif external_speaker_embedding_dim:
-                self.c_in_channels = external_speaker_embedding_dim
+            if self.c_in_channels == 0 and not self.external_speaker_embedding_dim:
+                self.c_in_channels = 512
+            elif self.external_speaker_embedding_dim:
+                self.c_in_channels = self.external_speaker_embedding_dim
 
         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
@@ -88,7 +89,7 @@ class GlowTts(nn.Module):
                                dropout_p=dropout_p,
                                mean_only=mean_only,
                                use_prenet=use_encoder_prenet,
-                               c_in_channels=c_in_channels)
+                               c_in_channels=self.c_in_channels)
 
         self.decoder = Decoder(out_channels,
                                hidden_channels_dec or hidden_channels,
@@ -100,10 +101,10 @@ class GlowTts(nn.Module):
                                num_splits=num_splits,
                                num_sqz=num_sqz,
                                sigmoid_scale=sigmoid_scale,
-                               c_in_channels=c_in_channels)
+                               c_in_channels=self.c_in_channels)
 
         if num_speakers > 1 and not external_speaker_embedding_dim:
-            self.emb_g = nn.Embedding(num_speakers, c_in_channels)
+            self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
 
     @staticmethod
@@ -130,7 +131,11 @@ class GlowTts(nn.Module):
         y_max_length = y.size(2)
         # norm speaker embeddings
         if g is not None:
-            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+
 
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths,
@@ -165,8 +170,13 @@ class GlowTts(nn.Module):
 
     @torch.no_grad()
     def inference(self, x, x_lengths, g=None):
+
         if g is not None:
-            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+
 
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths,
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index aacac898..2361fa85 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -126,7 +126,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                         mean_only=True,
                         hidden_channels_enc=192,
                         hidden_channels_dec=192,
-                        use_encoder_prenet=True)
+                        use_encoder_prenet=True,
+                        external_speaker_embedding_dim=speaker_embedding_dim)
     return model
 
 def is_tacotron(c):
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index f810e213..0dfea5cc 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -59,7 +59,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
             inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
     elif 'glow' in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
+        postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
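
For reference, a minimal standalone sketch of the two speaker-conditioning paths this patch wires into GlowTts.forward()/inference(): an external d-vector (loaded from speakers.json via speaker_mapping) is L2-normalized and used directly, while the internal path looks up a learned nn.Embedding by speaker id before normalizing. This assumes only plain PyTorch; the speaker count and the 512/256 dimensions are illustrative, not values read from any real config. In both cases g ends up shaped [B, C, 1] before being passed to the encoder and decoder.

# Hedged sketch of the conditioning logic added to GlowTts; illustrative sizes only.
import torch
import torch.nn.functional as F
from torch import nn

num_speakers = 4
c_in_channels = 512          # internal embedding size (the patched default)
external_dim = 256           # size of precomputed d-vectors in speakers.json

# internal path: learned lookup table, as in `self.emb_g = nn.Embedding(...)`
emb_g = nn.Embedding(num_speakers, c_in_channels)
nn.init.uniform_(emb_g.weight, -0.1, 0.1)

speaker_ids = torch.LongTensor([0, 2])                        # ids built from speaker_mapping names
g_internal = F.normalize(emb_g(speaker_ids)).unsqueeze(-1)    # [2, 512, 1]

# external path: precomputed speaker embeddings (data[8] in the training batch)
d_vectors = torch.randn(2, external_dim)                      # stand-in for speakers.json vectors
g_external = F.normalize(d_vectors).unsqueeze(-1)             # [2, 256, 1]

print(g_internal.shape, g_external.shape)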