From 7448177b72442b90136af6ba39423db3c4a70aeb Mon Sep 17 00:00:00 2001
From: Edresson
Date: Sat, 29 May 2021 21:11:53 -0300
Subject: [PATCH] use SpeakerManager on compute embeddings script

---
 TTS/bin/compute_embeddings.py                  | 35 ++++++++++---------
 TTS/bin/train_encoder.py                       |  4 +--
 .../configs/config_resnet_angleproto.json      |  2 +-
 TTS/speaker_encoder/dataset.py                 |  2 +-
 4 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index 9affac64..003da1e5 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.tts.datasets.preprocess import load_meta_data
-from TTS.tts.utils.speakers import save_speaker_mapping
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 from TTS.config import load_config, BaseDatasetConfig
 
@@ -28,7 +28,7 @@ parser.add_argument(
     default="",
     help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.",
 )
-parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False)
+parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|")
 
 args = parser.parse_args()
@@ -69,9 +69,6 @@ else:
     # Parse all wav files in data_path
     wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)
 
-
-os.makedirs(args.output_path, exist_ok=True)
-
 # define Encoder model
 model = setup_model(c)
 model.load_state_dict(torch.load(args.model_path)["model"])
@@ -85,6 +82,8 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     if isinstance(wav_file, list):
         speaker_name = wav_file[2]
         wav_file = wav_file[1]
+    else:
+        speaker_name = None
 
     mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
     mel_spec = torch.FloatTensor(mel_spec[None, :, :])
@@ -93,16 +92,20 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     embedd = model.compute_embedding(mel_spec)
     embedd = embedd.detach().cpu().numpy()
 
-    if args.target_dataset != "":
-        # create speaker_mapping if target dataset is defined
-        wav_file_name = os.path.basename(wav_file)
-        speaker_mapping[wav_file_name] = {}
-        speaker_mapping[wav_file_name]["name"] = speaker_name
-        speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
+    # create speaker_mapping if target dataset is defined
+    wav_file_name = os.path.basename(wav_file)
+    speaker_mapping[wav_file_name] = {}
+    speaker_mapping[wav_file_name]["name"] = speaker_name
+    speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
 
-if args.target_dataset != "":
-    if speaker_mapping:
-        # save speaker_mapping if target dataset is defined
+if speaker_mapping:
+    # save speaker_mapping if target dataset is defined
+    if '.json' not in args.output_path:
         mapping_file_path = os.path.join(args.output_path, "speakers.json")
-        save_speaker_mapping(args.output_path, speaker_mapping)
-        print("Speaker embedding saved at:", mapping_file_path)
+    else:
+        mapping_file_path = args.output_path
+    os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
+    speaker_manager = SpeakerManager()
+    # pylint: disable=W0212
+    speaker_manager._save_json(mapping_file_path, speaker_mapping)
+    print("Speaker embeddings saved at:", mapping_file_path)
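The save logic above changes the meaning of --output_path: it may now be either a directory (in which case "speakers.json" is appended) or a full path to a JSON file. A minimal standalone sketch of that resolution and of the shape of one saved mapping entry, assuming nothing beyond what the hunk shows (the example file name and embedding values are hypothetical):

    import os

    def resolve_mapping_path(output_path):
        # A path without ".json" is treated as a directory, mirroring
        # the patched script; otherwise it is used as the file path.
        if ".json" not in output_path:
            return os.path.join(output_path, "speakers.json")
        return output_path

    assert resolve_mapping_path("out") == os.path.join("out", "speakers.json")
    assert resolve_mapping_path("out/embeddings.json") == "out/embeddings.json"

    # One entry as written by the script: keyed by wav file basename,
    # holding the speaker name and the flattened embedding
    # (values here are hypothetical).
    speaker_mapping = {
        "p225_001.wav": {"name": "p225", "embedding": [0.12, -0.08, 0.33]},
    }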
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index a4191dfb..c9493535 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -9,7 +9,7 @@ import traceback
 import torch
 from torch.utils.data import DataLoader
 
-from TTS.speaker_encoder.dataset import MyDataset
+from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
 
@@ -35,7 +35,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
     if is_val:
         loader = None
     else:
-        dataset = MyDataset(
+        dataset = SpeakerEncoderDataset(
             ap,
             meta_data_eval if is_val else meta_data_train,
             voice_len=c.voice_len,
diff --git a/TTS/speaker_encoder/configs/config_resnet_angleproto.json b/TTS/speaker_encoder/configs/config_resnet_angleproto.json
index 95cf5ccf..c26d29ce 100644
--- a/TTS/speaker_encoder/configs/config_resnet_angleproto.json
+++ b/TTS/speaker_encoder/configs/config_resnet_angleproto.json
@@ -52,7 +52,7 @@
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "save_step": 1000, // Number of training steps expected to save the best checkpoints in training.
     "print_step": 50, // Number of steps to log traning on console.
-    "output_path": "../checkpoints/speaker_encoder/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto-continue/", // DATASET-RELATED: output path for all training outputs.
+    "output_path": "../checkpoints/speaker_encoder/angleproto/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto/", // DATASET-RELATED: output path for all training outputs.
 
     "audio_augmentation": {
         "p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation
diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py
index 45a7bc12..cd95a4f5 100644
--- a/TTS/speaker_encoder/dataset.py
+++ b/TTS/speaker_encoder/dataset.py
@@ -6,7 +6,7 @@ import torch
 from torch.utils.data import Dataset
 
 from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
-class MyDataset(Dataset):
+class SpeakerEncoderDataset(Dataset):
     def __init__(
         self,
         ap,
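With these renames in place, the speakers.json produced by compute_embeddings.py can be inspected with the standard library alone. A hedged sketch, assuming the file sits in the working directory (inside TTS itself, SpeakerManager is the intended consumer):

    import json

    # Load the mapping written by the patched compute_embeddings.py.
    with open("speakers.json", "r", encoding="utf-8") as f:
        speaker_mapping = json.load(f)

    for wav_name, entry in speaker_mapping.items():
        # "name" can be None when wav files were globbed directly rather
        # than taken from dataset metadata (the new else branch above).
        print(wav_name, entry["name"], len(entry["embedding"]))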