diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 3d34d978..f4d04abb 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -9,17 +9,17 @@ import time import traceback import torch +from random import randrange from torch.utils.data import DataLoader from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.utils.distribute import (DistributedSampler, init_distributed, reduce_tensor) -from TTS.tts.utils.generic_utils import setup_model +from TTS.tts.utils.generic_utils import setup_model, check_config_tts from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score -from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping, - save_speaker_mapping) +from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -36,8 +36,7 @@ from TTS.utils.training import (NoamLR, check_update, use_cuda, num_gpus = setup_torch_training_env(True, False) -def setup_loader(ap, r, is_val=False, verbose=False): - +def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None): if is_val and not c.run_eval: loader = None else: @@ -48,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False): meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, tp=c.characters if 'characters' in c.keys() else None, + add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, @@ -56,7 +56,8 @@ def setup_loader(ap, r, is_val=False, verbose=False): use_phonemes=c.use_phonemes, phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, - verbose=verbose) + verbose=verbose, + speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( dataset, @@ -86,10 +87,13 @@ def format_data(data): avg_spec_length = torch.mean(mel_lengths.float()) if c.use_speaker_embedding: - speaker_ids = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] - speaker_ids = torch.LongTensor(speaker_ids) + if c.use_external_speaker_embedding_file: + speaker_ids = data[8] + else: + speaker_ids = [ + speaker_mapping[speaker_name] for speaker_name in speaker_names + ] + speaker_ids = torch.LongTensor(speaker_ids) else: speaker_ids = None @@ -107,7 +111,7 @@ def format_data(data): avg_text_length, avg_spec_length, attn_mask -def data_depended_init(model, ap): +def data_depended_init(model, ap, speaker_mapping=None): """Data depended initialization for activation normalization.""" if hasattr(model, 'module'): for f in model.module.decoder.flows: @@ -118,19 +122,19 @@ def data_depended_init(model, ap): if getattr(f, "set_ddi", False): f.set_ddi(True) - data_loader = setup_loader(ap, 1, is_val=False) + data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping) model.train() print(" > Data depended initialization ... ") with torch.no_grad(): for _, data in enumerate(data_loader): # format data - text_input, text_lengths, mel_input, mel_lengths, _,\ + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ _, _, attn_mask = format_data(data) # forward pass model _ = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) break if hasattr(model, 'module'): @@ -145,9 +149,9 @@ def data_depended_init(model, ap): def train(model, criterion, optimizer, scheduler, - ap, global_step, epoch, amp): + ap, global_step, epoch, amp, speaker_mapping=None): data_loader = setup_loader(ap, 1, is_val=False, - verbose=(epoch == 0)) + verbose=(epoch == 0), speaker_mapping=speaker_mapping) model.train() epoch_time = 0 keep_avg = KeepAverage() @@ -162,7 +166,7 @@ def train(model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, _,\ + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ avg_text_length, avg_spec_length, attn_mask = format_data(data) loader_time = time.time() - end_time @@ -176,7 +180,7 @@ def train(model, criterion, optimizer, scheduler, # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) # compute loss loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, @@ -262,7 +266,7 @@ def train(model, criterion, optimizer, scheduler, # Diagnostic visualizations # direct pass on model for spec predictions - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1]) + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) spec_pred = spec_pred.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1) const_spec = spec_pred[0].data.cpu().numpy() @@ -298,8 +302,8 @@ def train(model, criterion, optimizer, scheduler, @torch.no_grad() -def evaluate(model, criterion, ap, global_step, epoch): - data_loader = setup_loader(ap, 1, is_val=True) +def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping): + data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping) model.eval() epoch_time = 0 keep_avg = KeepAverage() @@ -309,12 +313,12 @@ def evaluate(model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, _,\ + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ _, _, attn_mask = format_data(data) # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) # compute loss loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, @@ -356,9 +360,9 @@ def evaluate(model, criterion, ap, global_step, epoch): # Diagnostic visualizations # direct pass on model for spec predictions if hasattr(model, 'module'): - spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1]) + spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) else: - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1]) + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) spec_pred = spec_pred.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1) @@ -398,7 +402,17 @@ def evaluate(model, criterion, ap, global_step, epoch): test_audios = {} test_figures = {} print(" | > Synthesizing test sentences") - speaker_id = 0 if c.use_speaker_embedding else None + if c.use_speaker_embedding: + if c.use_external_speaker_embedding_file: + speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] + speaker_id = None + else: + speaker_id = 0 + speaker_embedding = None + else: + speaker_id = None + speaker_embedding = None + style_wav = c.get("style_wav_for_test") for idx, test_sentence in enumerate(test_sentences): try: @@ -409,6 +423,7 @@ def evaluate(model, criterion, ap, global_step, epoch): use_cuda, ap, speaker_id=speaker_id, + speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument @@ -459,26 +474,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] # parse speakers - if c.use_speaker_embedding: - speakers = get_speakers(meta_data_train) - if args.restore_path: - prev_out_path = os.path.dirname(args.restore_path) - speaker_mapping = load_speaker_mapping(prev_out_path) - assert all([speaker in speaker_mapping - for speaker in speakers]), "As of now you, you cannot " \ - "introduce new speakers to " \ - "a previously trained model." - else: - speaker_mapping = {name: i for i, name in enumerate(speakers)} - save_speaker_mapping(OUT_PATH, speaker_mapping) - num_speakers = len(speaker_mapping) - print("Training with {} speakers: {}".format(num_speakers, - ", ".join(speakers))) - else: - num_speakers = 0 + num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) # setup model - model = setup_model(num_chars, num_speakers, c) + model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim) optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) criterion = GlowTTSLoss() @@ -540,13 +539,13 @@ def main(args): # pylint: disable=redefined-outer-name best_loss = float('inf') global_step = args.restore_step - model = data_depended_init(model, ap) + model = data_depended_init(model, ap, speaker_mapping) for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) train_avg_loss_dict, global_step = train(model, criterion, optimizer, scheduler, ap, global_step, - epoch, amp) - eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) + epoch, amp, speaker_mapping) + eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = train_avg_loss_dict['avg_loss'] if c.run_eval: @@ -602,6 +601,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) # check_config(c) + check_config_tts(c) _ = os.path.dirname(os.path.realpath(__file__)) if c.apex_amp_level: diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 88e10aea..e4f8bf7a 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -22,8 +22,7 @@ from TTS.tts.utils.distribute import (DistributedSampler, from TTS.tts.utils.generic_utils import setup_model, check_config_tts from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score -from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping, - save_speaker_mapping) +from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -52,6 +51,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None): meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, tp=c.characters if 'characters' in c.keys() else None, + add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, @@ -502,42 +502,7 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] # parse speakers - if c.use_speaker_embedding: - speakers = get_speakers(meta_data_train) - if args.restore_path: - if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file - prev_out_path = os.path.dirname(args.restore_path) - speaker_mapping = load_speaker_mapping(prev_out_path) - if not speaker_mapping: - print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") - speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) - if not speaker_mapping: - raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file") - speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) - elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file - prev_out_path = os.path.dirname(args.restore_path) - speaker_mapping = load_speaker_mapping(prev_out_path) - speaker_embedding_dim = None - assert all([speaker in speaker_mapping - for speaker in speakers]), "As of now you, you cannot " \ - "introduce new speakers to " \ - "a previously trained model." - elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file - speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) - speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) - elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file - raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder" - else: # if start new train and don't use External Embedding file - speaker_mapping = {name: i for i, name in enumerate(speakers)} - speaker_embedding_dim = None - save_speaker_mapping(OUT_PATH, speaker_mapping) - num_speakers = len(speaker_mapping) - print("Training with {} speakers: {}".format(num_speakers, - ", ".join(speakers))) - else: - num_speakers = 0 - speaker_embedding_dim = None - speaker_mapping = None + num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim) diff --git a/TTS/tts/configs/glow_tts_gated_conv.json b/TTS/tts/configs/glow_tts_gated_conv.json index 696bdaf7..5c30e0bc 100644 --- a/TTS/tts/configs/glow_tts_gated_conv.json +++ b/TTS/tts/configs/glow_tts_gated_conv.json @@ -51,6 +51,8 @@ // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" // }, + "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model. + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", diff --git a/TTS/tts/configs/glow_tts_tdsep.json b/TTS/tts/configs/glow_tts_tdsep.json index 67047523..25d41291 100644 --- a/TTS/tts/configs/glow_tts_tdsep.json +++ b/TTS/tts/configs/glow_tts_tdsep.json @@ -51,6 +51,8 @@ // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" // }, + "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model. + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index ab8f3f88..7b671397 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -17,6 +17,7 @@ class MyDataset(Dataset): ap, meta_data, tp=None, + add_blank=False, batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), @@ -55,6 +56,7 @@ class MyDataset(Dataset): self.max_seq_len = max_seq_len self.ap = ap self.tp = tp + self.add_blank = add_blank self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language @@ -88,7 +90,7 @@ class MyDataset(Dataset): phonemes = phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=False, - tp=self.tp) + tp=self.tp, add_blank=self.add_blank) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @@ -127,7 +129,7 @@ class MyDataset(Dataset): text = self._load_or_generate_phoneme_sequence(wav_file, text) else: text = np.asarray(text_to_sequence(text, [self.cleaners], - tp=self.tp), + tp=self.tp, add_blank=self.add_blank), dtype=np.int32) assert text.size > 0, self.items[idx][1] diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 902de699..dec8243a 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -37,7 +37,8 @@ class GlowTts(nn.Module): hidden_channels_enc=None, hidden_channels_dec=None, use_encoder_prenet=False, - encoder_type="transformer"): + encoder_type="transformer", + external_speaker_embedding_dim=None): super().__init__() self.num_chars = num_chars @@ -67,6 +68,14 @@ class GlowTts(nn.Module): self.use_encoder_prenet = use_encoder_prenet self.noise_scale = 0.66 self.length_scale = 1. + self.external_speaker_embedding_dim = external_speaker_embedding_dim + + # if is a multispeaker and c_in_channels is 0, set to 256 + if num_speakers > 1: + if self.c_in_channels == 0 and not self.external_speaker_embedding_dim: + self.c_in_channels = 512 + elif self.external_speaker_embedding_dim: + self.c_in_channels = self.external_speaker_embedding_dim self.encoder = Encoder(num_chars, out_channels=out_channels, @@ -80,7 +89,7 @@ class GlowTts(nn.Module): dropout_p=dropout_p, mean_only=mean_only, use_prenet=use_encoder_prenet, - c_in_channels=c_in_channels) + c_in_channels=self.c_in_channels) self.decoder = Decoder(out_channels, hidden_channels_dec or hidden_channels, @@ -92,10 +101,10 @@ class GlowTts(nn.Module): num_splits=num_splits, num_sqz=num_sqz, sigmoid_scale=sigmoid_scale, - c_in_channels=c_in_channels) + c_in_channels=self.c_in_channels) - if num_speakers > 1: - self.emb_g = nn.Embedding(num_speakers, c_in_channels) + if num_speakers > 1 and not external_speaker_embedding_dim: + self.emb_g = nn.Embedding(num_speakers, self.c_in_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @staticmethod @@ -122,7 +131,11 @@ class GlowTts(nn.Module): y_max_length = y.size(2) # norm speaker embeddings if g is not None: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + if self.external_speaker_embedding_dim: + g = F.normalize(g).unsqueeze(-1) + else: + g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h] + # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, @@ -157,8 +170,13 @@ class GlowTts(nn.Module): @torch.no_grad() def inference(self, x, x_lengths, g=None): + if g is not None: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + if self.external_speaker_embedding_dim: + g = F.normalize(g).unsqueeze(-1) + else: + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 5480cbcd..2361fa85 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -126,13 +126,15 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): mean_only=True, hidden_channels_enc=192, hidden_channels_dec=192, - use_encoder_prenet=True) + use_encoder_prenet=True, + external_speaker_embedding_dim=speaker_embedding_dim) return model - +def is_tacotron(c): + return False if c['model'] == 'glow_tts' else True def check_config_tts(c): - check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) + check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) check_argument('run_description', c, val_type=str) @@ -195,27 +197,30 @@ def check_config_tts(c): check_argument('seq_len_norm', c, restricted=True, val_type=bool) # tacotron prenet - check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) - check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) - check_argument('prenet_dropout', c, restricted=True, val_type=bool) + check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1) + check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn']) + check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool) # attention - check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) - check_argument('attention_heads', c, restricted=True, val_type=int) - check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) - check_argument('windowing', c, restricted=True, val_type=bool) - check_argument('use_forward_attn', c, restricted=True, val_type=bool) - check_argument('forward_attn_mask', c, restricted=True, val_type=bool) - check_argument('transition_agent', c, restricted=True, val_type=bool) - check_argument('transition_agent', c, restricted=True, val_type=bool) - check_argument('location_attn', c, restricted=True, val_type=bool) - check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) - check_argument('double_decoder_consistency', c, restricted=True, val_type=bool) + check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original']) + check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int) + check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax']) + check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool) + check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool) + check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool) + check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) + check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) + check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool) + check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool) + check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool) check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) # stopnet - check_argument('stopnet', c, restricted=True, val_type=bool) - check_argument('separate_stopnet', c, restricted=True, val_type=bool) + check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) + check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + + # GlowTTS parameters + check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str) # tensorboard check_argument('print_step', c, restricted=True, val_type=int, min_val=1) @@ -240,15 +245,16 @@ def check_config_tts(c): # multi-speaker and gst check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool) - check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str) - check_argument('use_gst', c, restricted=True, val_type=bool) - check_argument('gst', c, restricted=True, val_type=dict) - check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict]) - check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool) - check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10) - check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000) + check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool) + check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str) + check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) + if c['use_gst']: + check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) + check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) + check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) + check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool) + check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) + check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) # datasets - checking only the first entry check_argument('datasets', c, restricted=True, val_type=list) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 156e42af..d507ff3d 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -30,3 +30,44 @@ def get_speakers(items): """Returns a sorted, unique list of speakers in a given dataset.""" speakers = {e[2] for e in items} return sorted(speakers) + +def parse_speakers(c, args, meta_data_train, OUT_PATH): + """ Returns number of speakers, speaker embedding shape and speaker mapping""" + if c.use_speaker_embedding: + speakers = get_speakers(meta_data_train) + if args.restore_path: + if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file + prev_out_path = os.path.dirname(args.restore_path) + speaker_mapping = load_speaker_mapping(prev_out_path) + if not speaker_mapping: + print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") + speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) + if not speaker_mapping: + raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file") + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) + elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file + prev_out_path = os.path.dirname(args.restore_path) + speaker_mapping = load_speaker_mapping(prev_out_path) + speaker_embedding_dim = None + assert all([speaker in speaker_mapping + for speaker in speakers]), "As of now you, you cannot " \ + "introduce new speakers to " \ + "a previously trained model." + elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file + speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) + elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file + raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder" + else: # if start new train and don't use External Embedding file + speaker_mapping = {name: i for i, name in enumerate(speakers)} + speaker_embedding_dim = None + save_speaker_mapping(OUT_PATH, speaker_mapping) + num_speakers = len(speaker_mapping) + print("Training with {} speakers: {}".format(len(speakers), + ", ".join(speakers))) + else: + num_speakers = 0 + speaker_embedding_dim = None + speaker_mapping = None + + return num_speakers, speaker_embedding_dim, speaker_mapping \ No newline at end of file diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index f810e213..3d2dd13c 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -14,10 +14,13 @@ def text_to_seqvec(text, CONFIG): seq = np.asarray( phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), + tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, + add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32) else: - seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) + seq = np.asarray( + text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, + add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32) return seq @@ -59,7 +62,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) elif 'glow' in CONFIG.model.lower(): inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable - postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths) + postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id else speaker_embeddings) postnet_output = postnet_output.permute(0, 2, 1) # these only belong to tacotron models. decoder_output = None diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 33972f25..29f4af1d 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -16,6 +16,8 @@ _id_to_symbol = {i: s for i, s in enumerate(symbols)} _phonemes_to_id = {s: i for i, s in enumerate(phonemes)} _id_to_phonemes = {i: s for i, s in enumerate(phonemes)} +_symbols = symbols +_phonemes = phonemes # Regular expression matching text enclosed in curly braces: _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)') @@ -57,6 +59,10 @@ def text2phone(text, language): return ph +def intersperse(sequence, token): + result = [token] * (len(sequence) * 2 + 1) + result[1::2] = sequence + return result def pad_with_eos_bos(phoneme_sequence, tp=None): # pylint: disable=global-statement @@ -69,10 +75,9 @@ def pad_with_eos_bos(phoneme_sequence, tp=None): return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] - -def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None): +def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False): # pylint: disable=global-statement - global _phonemes_to_id + global _phonemes_to_id, _phonemes if tp: _, _phonemes = make_symbols(**tp) _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} @@ -88,13 +93,17 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp= # Append EOS char if enable_eos_bos: sequence = pad_with_eos_bos(sequence, tp=tp) + if add_blank: + sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) return sequence -def sequence_to_phoneme(sequence, tp=None): +def sequence_to_phoneme(sequence, tp=None, add_blank=False): # pylint: disable=global-statement '''Converts a sequence of IDs back to a string''' - global _id_to_phonemes + global _id_to_phonemes, _phonemes + if add_blank: + sequence = list(filter(lambda x: x != len(_phonemes), sequence)) result = '' if tp: _, _phonemes = make_symbols(**tp) @@ -107,7 +116,7 @@ def sequence_to_phoneme(sequence, tp=None): return result.replace('}{', ' ') -def text_to_sequence(text, cleaner_names, tp=None): +def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded @@ -121,7 +130,7 @@ def text_to_sequence(text, cleaner_names, tp=None): List of integers corresponding to the symbols in the text ''' # pylint: disable=global-statement - global _symbol_to_id + global _symbol_to_id, _symbols if tp: _symbols, _ = make_symbols(**tp) _symbol_to_id = {s: i for i, s in enumerate(_symbols)} @@ -137,13 +146,19 @@ def text_to_sequence(text, cleaner_names, tp=None): _clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) + + if add_blank: + sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) return sequence -def sequence_to_text(sequence, tp=None): +def sequence_to_text(sequence, tp=None, add_blank=False): '''Converts a sequence of IDs back to a string''' # pylint: disable=global-statement - global _id_to_symbol + global _id_to_symbol, _symbols + if add_blank: + sequence = list(filter(lambda x: x != len(_symbols), sequence)) + if tp: _symbols, _ = make_symbols(**tp) _id_to_symbol = {i: s for i, s in enumerate(_symbols)} diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 1eb9f9a8..ae3250a8 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -11,6 +11,7 @@ from TTS.utils.io import load_config conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) def test_phoneme_to_sequence(): + text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" text_cleaner = ["phoneme_cleaners"] lang = "en-us" @@ -20,7 +21,7 @@ def test_phoneme_to_sequence(): text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" assert text_hat == text_hat_with_params == gt - + # multiple punctuations text = "Be a voice, not an! echo?" sequence = phoneme_to_sequence(text, text_cleaner, lang) @@ -87,6 +88,84 @@ def test_phoneme_to_sequence(): print(len(sequence)) assert text_hat == text_hat_with_params == gt +def test_phoneme_to_sequence_with_blank_token(): + + text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" + text_cleaner = ["phoneme_cleaners"] + lang = "en-us" + sequence = phoneme_to_sequence(text, text_cleaner, lang) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" + assert text_hat == text_hat_with_params == gt + + # multiple punctuations + text = "Be a voice, not an! echo?" + sequence = phoneme_to_sequence(text, text_cleaner, lang) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" + print(text_hat) + print(len(sequence)) + assert text_hat == text_hat_with_params == gt + + # not ending with punctuation + text = "Be a voice, not an! echo" + sequence = phoneme_to_sequence(text, text_cleaner, lang) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" + print(text_hat) + print(len(sequence)) + assert text_hat == text_hat_with_params == gt + + # original + text = "Be a voice, not an echo!" + sequence = phoneme_to_sequence(text, text_cleaner, lang) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" + print(text_hat) + print(len(sequence)) + assert text_hat == text_hat_with_params == gt + + # extra space after the sentence + text = "Be a voice, not an! echo. " + sequence = phoneme_to_sequence(text, text_cleaner, lang) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." + print(text_hat) + print(len(sequence)) + assert text_hat == text_hat_with_params == gt + + # extra space after the sentence + text = "Be a voice, not an! echo. " + sequence = phoneme_to_sequence(text, text_cleaner, lang, True) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~" + print(text_hat) + print(len(sequence)) + assert text_hat == text_hat_with_params == gt + + # padding char + text = "_Be a _voice, not an! echo_" + sequence = phoneme_to_sequence(text, text_cleaner, lang) + text_hat = sequence_to_phoneme(sequence) + _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) + gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" + print(text_hat) + print(len(sequence)) + assert text_hat == text_hat_with_params == gt + def test_text2phone(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"