From 4fdc864f746619c386a37c2b4f101780e5d77275 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 11 Mar 2022 19:01:51 -0300 Subject: [PATCH] Add EmbeddingManager and BaseIDManager --- TTS/tts/utils/managers.py | 279 ++++++++++++++++++++++++++++++++++++++ TTS/tts/utils/speakers.py | 264 ++---------------------------------- 2 files changed, 291 insertions(+), 252 deletions(-) diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index e69de29b..7242a31d 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -0,0 +1,279 @@ +import json +import random +from typing import Any, Dict, List, Union + +import fsspec +import numpy as np +import torch + +from TTS.config import load_config +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.utils.audio import AudioProcessor + + +class BaseIDManager: + """ Base `ID` Manager class. Every new `ID` manager must inherit this. + It defines common `ID` manager specific functions. + """ + def __init__( + self, + id_file_path: str = "" + ): + self.ids = {} + + if id_file_path: + self.set_ids_from_file(id_file_path) + + @staticmethod + def _load_json(json_file_path: str) -> Dict: + with fsspec.open(json_file_path, "r") as f: + return json.load(f) + + @staticmethod + def _save_json(json_file_path: str, data: dict) -> None: + with fsspec.open(json_file_path, "w") as f: + json.dump(data, f, indent=4) + + + def set_ids_from_data(self, items: List) -> None: + """Set IDs from data samples. + + Args: + items (List): Data sampled returned by `load_tts_samples()`. + """ + self.ids, _ = self.parse_ids_from_data(items) + + def set_ids_from_file(self, file_path: str) -> None: + """Set speaker IDs from a file. + + Args: + file_path (str): Path to the file. + """ + self.ids = self._load_json(file_path) + + def save_ids_to_file(self, file_path: str) -> None: + """Save speaker IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.ids) + + def get_random_speaker_id(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.ids: + return self.ids[random.choices(list(self.ids.keys()))[0]] + + return None + + @staticmethod + def parse_ids_from_data(items: list) -> Any: + raise NotImplementedError + + +class EmbeddingManager(BaseIDManager): + """ Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. + It defines common `Embedding` manager specific functions. + """ + def __init__( + self, + embedding_file_path: str = "", + id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__(id_file_path=id_file_path) + + self.embeddings = {} + self.embeddings_by_names = {} + self.clip_ids = [] + self.encoder = None + self.encoder_ap = None + self.use_cuda = use_cuda + + if embedding_file_path: + self.set_embeddings_from_file(embedding_file_path) + + if encoder_model_path and encoder_config_path: + self.init_encoder(encoder_model_path, encoder_config_path) + + @property + def embedding_dim(self): + """Dimensionality of embeddings. If embeddings are not loaded, returns zero.""" + if self.embeddings: + return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"]) + return 0 + + def save_embeddings_to_file(self, file_path: str) -> None: + """Save embeddings to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.embeddings) + + def set_embeddings_from_file(self, file_path: str) -> None: + """Load embeddings from a json file. + + Args: + file_path (str): Path to the target json file. + """ + self.embeddings = self._load_json(file_path) + + speakers = sorted({x["name"] for x in self.embeddings.values()}) + self.ids = {name: i for i, name in enumerate(speakers)} + + self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys()))) + # cache embeddings_by_names for fast inference using a bigger speakers.json + self.embeddings_by_names = self.get_embeddings_by_names() + + def get_embedding_by_clip(self, clip_idx: str) -> List: + """Get embedding by clip ID. + + Args: + clip_idx (str): Target clip ID. + + Returns: + List: embedding as a list. + """ + return self.embeddings[clip_idx]["embedding"] + + def get_embeddings_by_name(self, idx: str) -> List[List]: + """Get all embeddings of a speaker. + + Args: + idx (str): Target name. + + Returns: + List[List]: all the embeddings of the given speaker. + """ + return self.embeddings_by_names[idx] + + def get_embeddings_by_names(self) -> Dict: + """Get all embeddings by names. + + Returns: + Dict: all the embeddings of each speaker. + """ + embeddings_by_names = {} + for x in self.embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return embeddings_by_names + + def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: + """Get mean embedding of a idx. + + Args: + idx (str): Target name. + num_samples (int, optional): Number of samples to be averaged. Defaults to None. + randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False. + + Returns: + np.ndarray: Mean embedding. + """ + embeddings = self.get_embeddings_by_name(idx) + if num_samples is None: + embeddings = np.stack(embeddings).mean(0) + else: + assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}" + if randomize: + embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0) + else: + embeddings = np.stack(embeddings[:num_samples]).mean(0) + return embeddings + + def get_random_embedding(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.embeddings: + return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"] + + return None + + def get_clips(self) -> List: + return sorted(self.embeddings.keys()) + + def init_encoder(self, model_path: str, config_path: str) -> None: + """Initialize a speaker encoder model. + + Args: + model_path (str): Model file path. + config_path (str): Model config file path. + """ + self.encoder_config = load_config(config_path) + self.encoder = setup_encoder_model(self.encoder_config) + self.encoder_criterion = self.encoder.load_checkpoint(self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda) + self.encoder_ap = AudioProcessor(**self.encoder_config.audio) + + def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: + """Compute a embedding from a given audio file. + + Args: + wav_file (Union[str, List[str]]): Target file path. + + Returns: + list: Computed embedding. + """ + + def _compute(wav_file: str): + waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate) + if not self.encoder_config.model_params.get("use_torch_spec", False): + m_input = self.encoder_ap.melspectrogram(waveform) + m_input = torch.from_numpy(m_input) + else: + m_input = torch.from_numpy(waveform) + + if self.use_cuda: + m_input = m_input.cuda() + m_input = m_input.unsqueeze(0) + embedding = self.encoder.compute_embedding(m_input) + return embedding + + if isinstance(wav_file, list): + # compute the mean embedding + embeddings = None + for wf in wav_file: + embedding = _compute(wf) + if embeddings is None: + embeddings = embedding + else: + embeddings += embedding + return (embeddings / len(wav_file))[0].tolist() + embedding = _compute(wav_file) + return embedding[0].tolist() + + def compute_embedding(self, feats: Union[torch.Tensor, np.ndarray]) -> List: + """Compute embedding from features. + + Args: + feats (Union[torch.Tensor, np.ndarray]): Input features. + + Returns: + List: computed embedding. + """ + if isinstance(feats, np.ndarray): + feats = torch.from_numpy(feats) + if feats.ndim == 2: + feats = feats.unsqueeze(0) + if self.use_cuda: + feats = feats.cuda() + return self.encoder.compute_embedding(feats) + + @staticmethod + def parse_ids_from_data(items: list) -> Any: + raise NotImplementedError diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 9f1fc7a9..79631c65 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,6 +1,5 @@ import json import os -import random from typing import Any, Dict, List, Tuple, Union import fsspec @@ -8,12 +7,10 @@ import numpy as np import torch from coqpit import Coqpit -from TTS.config import get_from_config_or_model_args_with_default, load_config -from TTS.encoder.utils.generic_utils import setup_encoder_model -from TTS.utils.audio import AudioProcessor +from TTS.config import get_from_config_or_model_args_with_default +from TTS.tts.utils.managers import EmbeddingManager - -class SpeakerManager: +class SpeakerManager(EmbeddingManager): """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information in a way that can be queried by speaker or clip. @@ -62,36 +59,16 @@ class SpeakerManager: encoder_config_path: str = "", use_cuda: bool = False, ): - - self.embeddings = {} - self.ids = {} - self.embeddings_by_names = {} - self.clip_ids = [] - self.encoder = None - self.encoder_ap = None - self.use_cuda = use_cuda + super().__init__( + embedding_file_path=d_vectors_file_path, + id_file_path=speaker_id_file_path, + encoder_model_path=encoder_model_path, + encoder_config_path=encoder_config_path, + use_cuda=use_cuda + ) if data_items: - self.ids, _ = self.parse_from_data(data_items) - - if d_vectors_file_path: - self.set_embeddings_from_file(d_vectors_file_path) - - if speaker_id_file_path: - self.set_ids_from_file(speaker_id_file_path) - - if encoder_model_path and encoder_config_path: - self.init_encoder(encoder_model_path, encoder_config_path) - - @staticmethod - def _load_json(json_file_path: str) -> Dict: - with fsspec.open(json_file_path, "r") as f: - return json.load(f) - - @staticmethod - def _save_json(json_file_path: str, data: dict) -> None: - with fsspec.open(json_file_path, "w") as f: - json.dump(data, f, indent=4) + self.ids, _ = self.parse_ids_from_data(data_items) @property def num_speakers(self): @@ -101,15 +78,8 @@ class SpeakerManager: def speaker_names(self): return list(self.ids.keys()) - @property - def embedding_dim(self): - """Dimensionality of embeddings. If embeddings are not loaded, returns zero.""" - if self.embeddings: - return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"]) - return 0 - @staticmethod - def parse_from_data(items: list) -> Tuple[Dict, int]: + def parse_ids_from_data(items: list) -> Tuple[Dict, int]: """Parse speaker IDs from data samples retured by `load_tts_samples()`. Args: @@ -123,219 +93,9 @@ class SpeakerManager: num_speakers = len(speaker_ids) return speaker_ids, num_speakers - def set_ids_from_data(self, items: List) -> None: - """Set IDs from data samples. - - Args: - items (List): Data sampled returned by `load_tts_samples()`. - """ - self.ids, _ = self.parse_from_data(items) - - def set_ids_from_file(self, file_path: str) -> None: - """Set speaker IDs from a file. - - Args: - file_path (str): Path to the file. - """ - self.ids = self._load_json(file_path) - - def save_ids_to_file(self, file_path: str) -> None: - """Save speaker IDs to a json file. - - Args: - file_path (str): Path to the output file. - """ - self._save_json(file_path, self.ids) - - def save_embeddings_to_file(self, file_path: str) -> None: - """Save embeddings to a json file. - - Args: - file_path (str): Path to the output file. - """ - self._save_json(file_path, self.embeddings) - - def set_embeddings_from_file(self, file_path: str) -> None: - """Load embeddings from a json file. - - Args: - file_path (str): Path to the target json file. - """ - self.embeddings = self._load_json(file_path) - - speakers = sorted({x["name"] for x in self.embeddings.values()}) - self.ids = {name: i for i, name in enumerate(speakers)} - - self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys()))) - # cache embeddings_by_names for fast inference using a bigger speakers.json - self.embeddings_by_names = self.get_embeddings_by_names() - - def get_embedding_by_clip(self, clip_idx: str) -> List: - """Get embedding by clip ID. - - Args: - clip_idx (str): Target clip ID. - - Returns: - List: embedding as a list. - """ - return self.embeddings[clip_idx]["embedding"] - - def get_embeddings_by_name(self, idx: str) -> List[List]: - """Get all embeddings of a speaker. - - Args: - idx (str): Target name. - - Returns: - List[List]: all the embeddings of the given speaker. - """ - return self.embeddings_by_names[idx] - - def get_embeddings_by_names(self) -> Dict: - """Get all embeddings by names. - - Returns: - Dict: all the embeddings of each speaker. - """ - embeddings_by_names = {} - for x in self.embeddings.values(): - if x["name"] not in embeddings_by_names.keys(): - embeddings_by_names[x["name"]] = [x["embedding"]] - else: - embeddings_by_names[x["name"]].append(x["embedding"]) - return embeddings_by_names - - def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: - """Get mean embedding of a idx. - - Args: - idx (str): Target name. - num_samples (int, optional): Number of samples to be averaged. Defaults to None. - randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False. - - Returns: - np.ndarray: Mean embedding. - """ - embeddings = self.get_embeddings_by_name(idx) - if num_samples is None: - embeddings = np.stack(embeddings).mean(0) - else: - assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}" - if randomize: - embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0) - else: - embeddings = np.stack(embeddings[:num_samples]).mean(0) - return embeddings - - def get_random_speaker_id(self) -> Any: - """Get a random embedding. - - Args: - - Returns: - np.ndarray: embedding. - """ - if self.ids: - return self.ids[random.choices(list(self.ids.keys()))[0]] - - return None - - def get_random_embedding(self) -> Any: - """Get a random embedding. - - Args: - - Returns: - np.ndarray: embedding. - """ - if self.embeddings: - return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"] - - return None - def get_speakers(self) -> List: return self.ids - def get_clips(self) -> List: - return sorted(self.embeddings.keys()) - - def init_encoder(self, model_path: str, config_path: str) -> None: - """Initialize a speaker encoder model. - - Args: - model_path (str): Model file path. - config_path (str): Model config file path. - """ - self.encoder_config = load_config(config_path) - self.encoder = setup_encoder_model(self.encoder_config) - self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint( - self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda - ) - self.encoder_ap = AudioProcessor(**self.encoder_config.audio) - - def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: - """Compute a embedding from a given audio file. - - Args: - wav_file (Union[str, List[str]]): Target file path. - - Returns: - list: Computed embedding. - """ - - def _compute(wav_file: str): - waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate) - if not self.encoder_config.model_params.get("use_torch_spec", False): - m_input = self.encoder_ap.melspectrogram(waveform) - m_input = torch.from_numpy(m_input) - else: - m_input = torch.from_numpy(waveform) - - if self.use_cuda: - m_input = m_input.cuda() - m_input = m_input.unsqueeze(0) - embedding = self.encoder.compute_embedding(m_input) - return embedding - - if isinstance(wav_file, list): - # compute the mean embedding - embeddings = None - for wf in wav_file: - embedding = _compute(wf) - if embeddings is None: - embeddings = embedding - else: - embeddings += embedding - return (embeddings / len(wav_file))[0].tolist() - embedding = _compute(wav_file) - return embedding[0].tolist() - - def compute_embedding(self, feats: Union[torch.Tensor, np.ndarray]) -> List: - """Compute embedding from features. - - Args: - feats (Union[torch.Tensor, np.ndarray]): Input features. - - Returns: - List: computed embedding. - """ - if isinstance(feats, np.ndarray): - feats = torch.from_numpy(feats) - if feats.ndim == 2: - feats = feats.unsqueeze(0) - if self.use_cuda: - feats = feats.cuda() - return self.encoder.compute_embedding(feats) - - def run_umap(self): - # TODO: implement speaker encoder - raise NotImplementedError - - def plot_embeddings(self): - # TODO: implement speaker encoder - raise NotImplementedError - @staticmethod def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": """Initialize a speaker manager from config