From 6f33506d892cb3ec899aee5fb7829e6c74824922 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Tue, 15 Mar 2022 19:40:07 +0000
Subject: [PATCH] Fix unit tests

---
 TTS/bin/compute_embeddings.py | 48 ++++++++++++++++++++++++------------
 TTS/bin/synthesize.py         |  1 +
 TTS/utils/synthesizer.py      |  4 ++--
 3 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index d7fe3c4b..67b17241 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -8,7 +8,7 @@ from tqdm import tqdm
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.managers import save_file
-from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.managers import EmbeddingManager
 
 parser = argparse.ArgumentParser(
     description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
@@ -25,6 +25,7 @@ parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.")
 parser.add_argument("--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None)
 parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False)
 parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False)
+parser.add_argument("--use_predicted_label", type=bool, help="If True, use the encoder's predicted label when one is available.", default=False)
 
 args = parser.parse_args()
 
@@ -39,20 +40,20 @@ if meta_data_eval is None:
 else:
     wav_files = meta_data_train + meta_data_eval
 
-encoder_manager = SpeakerManager(
+encoder_manager = EmbeddingManager(
     encoder_model_path=args.model_path,
     encoder_config_path=args.config_path,
-    d_vectors_file_path=args.old_file,
+    embedding_file_path=args.old_file,
     use_cuda=use_cuda,
 )
 
 class_name_key = encoder_manager.encoder_config.class_name_key
 
 # compute speaker embeddings
-speaker_mapping = {}
+class_mapping = {}
 for idx, wav_file in enumerate(tqdm(wav_files)):
     if isinstance(wav_file, dict):
-        class_name = wav_file[class_name_key]
+        class_name = wav_file[class_name_key] if class_name_key in wav_file else None
         wav_file = wav_file["audio_file"]
     else:
         class_name = None
@@ -65,20 +66,37 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     # extract the embedding
     embedd = encoder_manager.compute_embedding_from_clip(wav_file)
 
-    # create speaker_mapping if target dataset is defined
-    speaker_mapping[wav_file_name] = {}
-    speaker_mapping[wav_file_name]["name"] = class_name
-    speaker_mapping[wav_file_name]["embedding"] = embedd
+    if args.use_predicted_label:
+        map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
+        if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
+            embedding = torch.FloatTensor(embedd).unsqueeze(0)
+            if encoder_manager.use_cuda:
+                embedding = embedding.cuda()
 
-if speaker_mapping:
-    # save speaker_mapping if target dataset is defined
-    if os.path.isdir(args.output_path):
-        mapping_file_path = os.path.join(args.output_path, "speakers.pth")
+            class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
+            class_name = map_classid_to_classname[str(class_id)]
+        else:
+            raise RuntimeError(
+                " [!] use_predicted_label is enabled but predicted labels are not available!"
+            )
+
+    # create class_mapping if target dataset is defined
+    class_mapping[wav_file_name] = {}
+    class_mapping[wav_file_name]["name"] = class_name
+    class_mapping[wav_file_name]["embedding"] = embedd
+
+if class_mapping:
+    # save class_mapping if target dataset is defined
+    if ".json" not in args.output_path and ".pth" not in args.output_path:
+        if class_name_key == "speaker_name":
+            mapping_file_path = os.path.join(args.output_path, "speakers.pth")
+        else:
+            mapping_file_path = os.path.join(args.output_path, "emotions.pth")
     else:
         mapping_file_path = args.output_path
     if os.path.dirname(mapping_file_path) != "":
         os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
-    save_file(speaker_mapping, mapping_file_path)
-    print("Speaker embeddings saved at:", mapping_file_path)
+    save_file(class_mapping, mapping_file_path)
+    print("Embeddings saved at:", mapping_file_path)
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 900f5df7..a623a8a8 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -237,6 +237,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
     model_path = None
     config_path = None
     speakers_file_path = None
+    emotions_file_path = None
     language_ids_file_path = None
     vocoder_path = None
     vocoder_config_path = None
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 91526c56..a48959f4 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -265,9 +265,9 @@ class Synthesizer(object):
 
         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if self.tts_emotions_file or hasattr(self.tts_model.emotion_manager, "ids"):
+        if self.tts_emotions_file or (getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)):
             if emotion_name and isinstance(emotion_name, str):
-                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False):
+                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (getattr(self.tts_config, "model_args", None) and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)):
                     # get the average speaker embedding from the saved embeddings.
                     emotion_embedding = self.tts_model.emotion_manager.get_mean_embedding(emotion_name, num_samples=None, randomize=False)
                     emotion_embedding = np.array(emotion_embedding)[None, :]  # [1 x embedding_dim]
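--
For reviewers: a minimal sketch of the EmbeddingManager flow that
compute_embeddings.py is moved onto by this patch, useful for testing it in
isolation. The checkpoint, config, and wav paths below are placeholders, not
files shipped with the patch:

    from TTS.tts.utils.managers import EmbeddingManager

    # Load a trained encoder (placeholder paths).
    manager = EmbeddingManager(
        encoder_model_path="encoder_model.pth",
        encoder_config_path="encoder_config.json",
        use_cuda=False,
    )

    # Same per-clip call the script makes; returns the embedding vector
    # that compute_embeddings.py stores under the "embedding" key.
    embedd = manager.compute_embedding_from_clip("sample.wav")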