From f840268181b2c3b0570083a322a0ccff3fd37ef2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 28 May 2021 15:46:28 +0200
Subject: [PATCH] refactor `SpeakerManager`

---
 TTS/tts/utils/speakers.py | 199 +++++++++++++++++++++++++++++++++-----
 1 file changed, 177 insertions(+), 22 deletions(-)

diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 3239e9a5..5c10c589 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -1,6 +1,7 @@
 import json
+import os
 import random
-from typing import Any, List, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
 import torch
@@ -10,6 +11,71 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.utils.audio import AudioProcessor
 
 
+def make_speakers_json_path(out_path):
+    """Returns the conventional speakers.json location."""
+    return os.path.join(out_path, "speakers.json")
+
+
+def load_speaker_mapping(out_path):
+    """Loads a speaker mapping from a json file or a directory containing `speakers.json`."""
+    if os.path.splitext(out_path)[1] == ".json":
+        json_file = out_path
+    else:
+        json_file = make_speakers_json_path(out_path)
+    with open(json_file) as f:
+        return json.load(f)
+
+
+def save_speaker_mapping(out_path, speaker_mapping):
+    """Saves a speaker mapping to `speakers.json` under `out_path`."""
+    if out_path is not None:
+        speakers_json_path = make_speakers_json_path(out_path)
+        with open(speakers_json_path, "w") as f:
+            json.dump(speaker_mapping, f, indent=4)
+
+
+def get_speaker_manager(c, args, meta_data_train):
+    """Initialize and return a `SpeakerManager` based on config values."""
+    speaker_manager = SpeakerManager()
+    if c.use_speaker_embedding:
+        speaker_manager.set_speaker_ids_from_data(meta_data_train)
+        if args.restore_path:
+            # restoring speaker manager from a previous run.
+            if c.use_external_speaker_embedding_file:
+                # restore speaker manager with the embedding file
+                speakers_file = make_speakers_json_path(os.path.dirname(args.restore_path))
+                if not os.path.exists(speakers_file):
+                    print(
+                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
+                    )
+                    if not os.path.exists(c.external_speaker_embedding_file):
+                        raise RuntimeError(
+                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
+                        )
+                    speakers_file = c.external_speaker_embedding_file
+                speaker_manager.set_x_vectors_from_file(speakers_file)
+            elif not c.use_external_speaker_embedding_file:  # restore speaker manager with speaker ID file.
+                speakers_file = make_speakers_json_path(os.path.dirname(args.restore_path))
+                speaker_ids_from_data = speaker_manager.speaker_ids
+                speaker_manager.set_speaker_ids_from_file(speakers_file)
+                assert all(
+                    speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data
+                ), " [!] You cannot introduce new speakers to a pre-trained model."
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:
+            # new speaker manager with external speaker embeddings.
+            speaker_manager.set_x_vectors_from_file(c.external_speaker_embedding_file)
+        elif (
+            c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
+        ):  # new speaker manager with speaker IDs file.
+ raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder" + print( + " > Training with {} speakers: {}".format( + speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids) + ) + ) + return speaker_manager + + class SpeakerManager: """It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information in a way that you can query. There are 3 different scenarios considered. @@ -64,24 +130,24 @@ class SpeakerManager: self.speaker_encoder_ap = None if data_items: - self.speaker_ids = self.parse_speakers() + self.speaker_ids, _ = self.parse_speakers_from_data(self.data_items) if x_vectors_file_path: - self.load_x_vectors_file(x_vectors_file_path) + self.set_x_vectors_from_file(x_vectors_file_path) if speaker_id_file_path: - self.load_ids_file(speaker_id_file_path) + self.set_speaker_ids_from_file(speaker_id_file_path) if encoder_model_path and encoder_config_path: self.init_speaker_encoder(encoder_model_path, encoder_config_path) @staticmethod - def _load_json(json_file_path: str): + def _load_json(json_file_path: str) -> Dict: with open(json_file_path) as f: return json.load(f) @staticmethod - def _save_json(json_file_path: str, data: dict): + def _save_json(json_file_path: str, data: dict) -> None: with open(json_file_path, "w") as f: json.dump(data, f, indent=4) @@ -91,35 +157,101 @@ class SpeakerManager: @property def x_vector_dim(self): - return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"]) + """Dimensionality of x_vectors. If x_vectors are not loaded, returns zero.""" + if self.x_vectors: + return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"]) + return 0 - def parse_speakers_from_items(self, items: list): + @staticmethod + def parse_speakers_from_data(items: list) -> Tuple[Dict, int]: + """Parse speaker IDs from data samples retured by `load_meta_data()`. + + Args: + items (list): Data sampled returned by `load_meta_data()`. + + Returns: + Tuple[Dict, int]: speaker IDs and number of speakers. + """ speakers = sorted({item[2] for item in items}) - self.speaker_ids = {name: i for i, name in enumerate(speakers)} - num_speakers = len(self.speaker_ids) - return self.speaker_ids, num_speakers + speaker_ids = {name: i for i, name in enumerate(speakers)} + num_speakers = len(speaker_ids) + return speaker_ids, num_speakers - def save_ids_file(self, file_path: str): - self._save_json(file_path, self.speaker_ids) + def set_speaker_ids_from_data(self, items: List) -> None: + """Set speaker IDs from data samples. - def load_ids_file(self, file_path: str): + Args: + items (List): Data sampled returned by `load_meta_data()`. + """ + self.speaker_ids, _ = self.parse_speakers_from_data(items) + + def set_speaker_ids_from_file(self, file_path: str) -> None: + """Set speaker IDs from a file. + + Args: + file_path (str): Path to the file. + """ self.speaker_ids = self._load_json(file_path) - def save_x_vectors_file(self, file_path: str): + def save_speaker_ids_to_file(self, file_path: str) -> None: + """Save speaker IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.speaker_ids) + + def save_x_vectors_to_file(self, file_path: str) -> None: + """Save x_vectors to a json file. + + Args: + file_path (str): Path to the output file. 
+ """ self._save_json(file_path, self.x_vectors) - def load_x_vectors_file(self, file_path: str): + def set_x_vectors_from_file(self, file_path: str) -> None: + """Load x_vectors from a json file. + + Args: + file_path (str): Path to the target json file. + """ self.x_vectors = self._load_json(file_path) self.speaker_ids = list(set(sorted(x["name"] for x in self.x_vectors.values()))) self.clip_ids = list(set(sorted(clip_name for clip_name in self.x_vectors.keys()))) - def get_x_vector_by_clip(self, clip_idx: str): + def get_x_vector_by_clip(self, clip_idx: str) -> List: + """Get x_vector by clip ID. + + Args: + clip_idx (str): Target clip ID. + + Returns: + List: x_vector as a list. + """ return self.x_vectors[clip_idx]["embedding"] - def get_x_vectors_by_speaker(self, speaker_idx: str): + def get_x_vectors_by_speaker(self, speaker_idx: str) -> List[List]: + """Get all x_vectors of a speaker. + + Args: + speaker_idx (str): Target speaker ID. + + Returns: + List[List]: all the x_vectors of the given speaker. + """ return [x["embedding"] for x in self.x_vectors.values() if x["name"] == speaker_idx] - def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False): + def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.Array: + """Get mean x_vector of a speaker ID. + + Args: + speaker_idx (str): Target speaker ID. + num_samples (int, optional): Number of samples to be averaged. Defaults to None. + randomize (bool, optional): Pick random `num_samples`of x_vectors. Defaults to False. + + Returns: + np.Array: Mean x_vector. + """ x_vectors = self.get_x_vectors_by_speaker(speaker_idx) if num_samples is None: x_vectors = np.stack(x_vectors).mean(0) @@ -131,13 +263,19 @@ class SpeakerManager: x_vectors = np.stack(x_vectors[:num_samples]).mean(0) return x_vectors - def get_speakers(self): + def get_speakers(self) -> List: return self.speaker_ids - def get_clips(self): + def get_clips(self) -> List: return sorted(self.x_vectors.keys()) def init_speaker_encoder(self, model_path: str, config_path: str) -> None: + """Initialize a speaker encoder model. + + Args: + model_path (str): Model file path. + config_path (str): Model config file path. + """ self.speaker_encoder_config = load_config(config_path) self.speaker_encoder = setup_model(self.speaker_encoder_config) self.speaker_encoder.load_checkpoint(config_path, model_path, True) @@ -147,6 +285,15 @@ class SpeakerManager: self.speaker_encoder_ap.do_trim_silence = True def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list: + """Compute a x_vector from a given audio file. + + Args: + wav_file (Union[str, list]): Target file path. + + Returns: + list: Computed x_vector. + """ + def _compute(wav_file: str): waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) spec = self.speaker_encoder_ap.melspectrogram(waveform) @@ -168,7 +315,15 @@ class SpeakerManager: x_vector = _compute(wav_file) return x_vector[0].tolist() - def compute_x_vector(self, feats): + def compute_x_vector(self, feats: Union[torch.Tensor, np.Array]) -> List: + """Compute x_vector from features. + + Args: + feats (Union[torch.Tensor, np.Array]): Input features. + + Returns: + List: computed x_vector. + """ if isinstance(feats, np.ndarray): feats = torch.from_numpy(feats) if feats.ndim == 2: