mirror of https://github.com/coqui-ai/TTS.git
Add EmbeddingManager and BaseIDManager
This commit is contained in:
parent
40df2cfdd1
commit
4fdc864f74
|
@ -0,0 +1,279 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class BaseIDManager:
    """Base `ID` Manager class. Every new `ID` manager must inherit this.

    It defines common `ID` manager specific functions: loading and saving the
    ID mapping as JSON and drawing a random ID. Subclasses must implement
    `parse_ids_from_data()`.

    Args:
        id_file_path (str, optional): Path to a JSON file holding the ID mapping.
            When non-empty, the mapping is loaded at construction time. Defaults to "".
    """

    def __init__(self, id_file_path: str = ""):
        # Mapping filled either from a JSON file or from `parse_ids_from_data()`.
        self.ids = {}

        if id_file_path:
            self.set_ids_from_file(id_file_path)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        """Open `json_file_path` through fsspec and return its parsed JSON content."""
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        """Write `data` as JSON (4-space indent) to `json_file_path` through fsspec."""
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    def set_ids_from_data(self, items: List) -> None:
        """Set IDs from data samples.

        Args:
            items (List): Data samples returned by `load_tts_samples()`.
        """
        # `parse_ids_from_data()` returns a pair; only its first element is kept.
        self.ids, _ = self.parse_ids_from_data(items)

    def set_ids_from_file(self, file_path: str) -> None:
        """Set IDs from a json file.

        Args:
            file_path (str): Path to the file.
        """
        self.ids = self._load_json(file_path)

    def save_ids_to_file(self, file_path: str) -> None:
        """Save IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.ids)

    def get_random_speaker_id(self) -> Any:
        """Get a random ID.

        Returns:
            Any: The value stored in `self.ids` under a randomly picked key, or
                `None` when no IDs are set.
        """
        if self.ids:
            # `random.choices(...)[0]` is kept (instead of `random.choice`) so that
            # seeded runs keep drawing the same entries as before.
            return self.ids[random.choices(list(self.ids.keys()))[0]]

        return None

    @staticmethod
    def parse_ids_from_data(items: list) -> Any:
        """Parse IDs from data samples. Must be implemented by subclasses.

        `set_ids_from_data()` unpacks the result into two values and stores the
        first one in `self.ids`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class EmbeddingManager(BaseIDManager):
    """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.

    It defines common `Embedding` manager specific functions: loading and saving
    embeddings as JSON, querying them by clip or by name, and computing new
    embeddings with an encoder model.

    The embedding file is a JSON object mapping a clip name to a dict with at
    least the keys `"name"` and `"embedding"`.

    Args:
        embedding_file_path (str, optional): Path to the embeddings JSON file. Defaults to "".
        id_file_path (str, optional): Path to the ID mapping JSON file. Defaults to "".
        encoder_model_path (str, optional): Path to the encoder checkpoint. Defaults to "".
        encoder_config_path (str, optional): Path to the encoder config. Defaults to "".
        use_cuda (bool, optional): Move encoder inputs to CUDA and pass the flag to
            the checkpoint loader. Defaults to False.
    """

    def __init__(
        self,
        embedding_file_path: str = "",
        id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
        use_cuda: bool = False,
    ):
        super().__init__(id_file_path=id_file_path)

        self.embeddings = {}
        self.embeddings_by_names = {}
        self.clip_ids = []
        self.encoder = None
        self.encoder_ap = None
        # Declared up front so the attributes exist even before `init_encoder()` runs.
        self.encoder_config = None
        self.encoder_criterion = None
        self.use_cuda = use_cuda

        if embedding_file_path:
            # NOTE: this rebuilds `self.ids` from the embedding file and therefore
            # replaces any IDs already loaded from `id_file_path`.
            self.set_embeddings_from_file(embedding_file_path)

        # The encoder is only set up when both the checkpoint and its config are given.
        if encoder_model_path and encoder_config_path:
            self.init_encoder(encoder_model_path, encoder_config_path)

    @property
    def embedding_dim(self):
        """Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
        if self.embeddings:
            # Length of the first stored embedding, read without copying the key list.
            return len(next(iter(self.embeddings.values()))["embedding"])
        return 0

    def save_embeddings_to_file(self, file_path: str) -> None:
        """Save embeddings to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.embeddings)

    def set_embeddings_from_file(self, file_path: str) -> None:
        """Load embeddings from a json file.

        Besides `self.embeddings`, this also rebuilds `self.ids` (name -> index,
        names sorted), `self.clip_ids` and the `self.embeddings_by_names` cache.

        Args:
            file_path (str): Path to the target json file.
        """
        self.embeddings = self._load_json(file_path)

        speakers = sorted({x["name"] for x in self.embeddings.values()})
        self.ids = {name: i for i, name in enumerate(speakers)}

        # Dict keys are already unique; sorting last keeps the order deterministic.
        self.clip_ids = sorted(self.embeddings)
        # cache embeddings_by_names for fast inference using a bigger speakers.json
        self.embeddings_by_names = self.get_embeddings_by_names()

    def get_embedding_by_clip(self, clip_idx: str) -> List:
        """Get embedding by clip ID.

        Args:
            clip_idx (str): Target clip ID.

        Returns:
            List: embedding as a list.
        """
        return self.embeddings[clip_idx]["embedding"]

    def get_embeddings_by_name(self, idx: str) -> List[List]:
        """Get all embeddings of a speaker.

        Reads from the `self.embeddings_by_names` cache built when embeddings are
        loaded from a file.

        Args:
            idx (str): Target name.

        Returns:
            List[List]: all the embeddings of the given speaker.
        """
        return self.embeddings_by_names[idx]

    def get_embeddings_by_names(self) -> Dict:
        """Get all embeddings by names.

        Returns:
            Dict: all the embeddings of each speaker, keyed by name.
        """
        embeddings_by_names = {}
        for x in self.embeddings.values():
            embeddings_by_names.setdefault(x["name"], []).append(x["embedding"])
        return embeddings_by_names

    def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
        """Get mean embedding of a idx.

        Args:
            idx (str): Target name.
            num_samples (int, optional): Number of samples to be averaged. Defaults to None,
                which averages all embeddings of `idx`.
            randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False,
                which takes the first `num_samples`.

        Returns:
            np.ndarray: Mean embedding.
        """
        embeddings = self.get_embeddings_by_name(idx)
        if num_samples is None:
            return np.stack(embeddings).mean(0)

        assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
        if randomize:
            # Sample without replacement so `num_samples` distinct entries are averaged,
            # which is what the size check above guards for.
            picked = random.sample(embeddings, k=num_samples)
        else:
            picked = embeddings[:num_samples]
        return np.stack(picked).mean(0)

    def get_random_embedding(self) -> Any:
        """Get a random embedding.

        Returns:
            Any: The `"embedding"` entry of a randomly picked clip, or `None` when
                no embeddings are loaded.
        """
        if self.embeddings:
            return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"]

        return None

    def get_clips(self) -> List:
        """Return the sorted list of clip names that have an embedding."""
        return sorted(self.embeddings.keys())

    def init_encoder(self, model_path: str, config_path: str) -> None:
        """Initialize a speaker encoder model.

        Sets `self.encoder_config`, `self.encoder`, `self.encoder_criterion` and
        `self.encoder_ap`.

        Args:
            model_path (str): Model file path.
            config_path (str): Model config file path.
        """
        self.encoder_config = load_config(config_path)
        self.encoder = setup_encoder_model(self.encoder_config)
        self.encoder_criterion = self.encoder.load_checkpoint(
            self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
        )
        self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

    def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
        """Compute a embedding from a given audio file.

        Requires `init_encoder()` to have been called.

        Args:
            wav_file (Union[str, List[str]]): Target file path. When a list is given,
                the mean of the per-file embeddings is returned; the list must not be empty.

        Returns:
            list: Computed embedding.
        """

        def _compute(wav_file: str):
            waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
            # The encoder is fed the raw waveform when `use_torch_spec` is set in its
            # config, and a mel spectrogram computed here otherwise.
            if not self.encoder_config.model_params.get("use_torch_spec", False):
                m_input = self.encoder_ap.melspectrogram(waveform)
                m_input = torch.from_numpy(m_input)
            else:
                m_input = torch.from_numpy(waveform)

            if self.use_cuda:
                m_input = m_input.cuda()
            m_input = m_input.unsqueeze(0)
            embedding = self.encoder.compute_embedding(m_input)
            return embedding

        if isinstance(wav_file, list):
            # compute the mean embedding
            embeddings = None
            for wf in wav_file:
                embedding = _compute(wf)
                if embeddings is None:
                    embeddings = embedding
                else:
                    embeddings += embedding
            return (embeddings / len(wav_file))[0].tolist()
        embedding = _compute(wav_file)
        return embedding[0].tolist()

    def compute_embedding(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
        """Compute embedding from features.

        Args:
            feats (Union[torch.Tensor, np.ndarray]): Input features. A numpy array is
                converted to a tensor and a 2-dimensional input gets a leading dimension added.

        Returns:
            List: computed embedding, exactly as returned by the encoder's
                `compute_embedding()` (NOTE(review): presumably a tensor rather than a
                Python list - confirm against the encoder implementation).
        """
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats)
        if feats.ndim == 2:
            feats = feats.unsqueeze(0)
        if self.use_cuda:
            feats = feats.cuda()
        return self.encoder.compute_embedding(feats)

    @staticmethod
    def parse_ids_from_data(items: list) -> Any:
        """Parse IDs from data samples. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this class.
        """
        raise NotImplementedError
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import fsspec
|
||||
|
@ -8,12 +7,10 @@ import numpy as np
|
|||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import get_from_config_or_model_args_with_default, load_config
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.config import get_from_config_or_model_args_with_default
|
||||
from TTS.tts.utils.managers import EmbeddingManager
|
||||
|
||||
|
||||
class SpeakerManager:
|
||||
class SpeakerManager(EmbeddingManager):
|
||||
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
|
||||
in a way that can be queried by speaker or clip.
|
||||
|
||||
|
@ -62,36 +59,16 @@ class SpeakerManager:
|
|||
encoder_config_path: str = "",
|
||||
use_cuda: bool = False,
|
||||
):
|
||||
|
||||
self.embeddings = {}
|
||||
self.ids = {}
|
||||
self.embeddings_by_names = {}
|
||||
self.clip_ids = []
|
||||
self.encoder = None
|
||||
self.encoder_ap = None
|
||||
self.use_cuda = use_cuda
|
||||
super().__init__(
|
||||
embedding_file_path=d_vectors_file_path,
|
||||
id_file_path=speaker_id_file_path,
|
||||
encoder_model_path=encoder_model_path,
|
||||
encoder_config_path=encoder_config_path,
|
||||
use_cuda=use_cuda
|
||||
)
|
||||
|
||||
if data_items:
|
||||
self.ids, _ = self.parse_from_data(data_items)
|
||||
|
||||
if d_vectors_file_path:
|
||||
self.set_embeddings_from_file(d_vectors_file_path)
|
||||
|
||||
if speaker_id_file_path:
|
||||
self.set_ids_from_file(speaker_id_file_path)
|
||||
|
||||
if encoder_model_path and encoder_config_path:
|
||||
self.init_encoder(encoder_model_path, encoder_config_path)
|
||||
|
||||
@staticmethod
|
||||
def _load_json(json_file_path: str) -> Dict:
|
||||
with fsspec.open(json_file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def _save_json(json_file_path: str, data: dict) -> None:
|
||||
with fsspec.open(json_file_path, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
self.ids, _ = self.parse_ids_from_data(data_items)
|
||||
|
||||
@property
|
||||
def num_speakers(self):
|
||||
|
@ -101,15 +78,8 @@ class SpeakerManager:
|
|||
def speaker_names(self):
|
||||
return list(self.ids.keys())
|
||||
|
||||
@property
|
||||
def embedding_dim(self):
|
||||
"""Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
|
||||
if self.embeddings:
|
||||
return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"])
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def parse_from_data(items: list) -> Tuple[Dict, int]:
|
||||
def parse_ids_from_data(items: list) -> Tuple[Dict, int]:
|
||||
"""Parse speaker IDs from data samples retured by `load_tts_samples()`.
|
||||
|
||||
Args:
|
||||
|
@ -123,219 +93,9 @@ class SpeakerManager:
|
|||
num_speakers = len(speaker_ids)
|
||||
return speaker_ids, num_speakers
|
||||
|
||||
def set_ids_from_data(self, items: List) -> None:
|
||||
"""Set IDs from data samples.
|
||||
|
||||
Args:
|
||||
items (List): Data sampled returned by `load_tts_samples()`.
|
||||
"""
|
||||
self.ids, _ = self.parse_from_data(items)
|
||||
|
||||
def set_ids_from_file(self, file_path: str) -> None:
|
||||
"""Set speaker IDs from a file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the file.
|
||||
"""
|
||||
self.ids = self._load_json(file_path)
|
||||
|
||||
def save_ids_to_file(self, file_path: str) -> None:
|
||||
"""Save speaker IDs to a json file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the output file.
|
||||
"""
|
||||
self._save_json(file_path, self.ids)
|
||||
|
||||
def save_embeddings_to_file(self, file_path: str) -> None:
|
||||
"""Save embeddings to a json file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the output file.
|
||||
"""
|
||||
self._save_json(file_path, self.embeddings)
|
||||
|
||||
def set_embeddings_from_file(self, file_path: str) -> None:
|
||||
"""Load embeddings from a json file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the target json file.
|
||||
"""
|
||||
self.embeddings = self._load_json(file_path)
|
||||
|
||||
speakers = sorted({x["name"] for x in self.embeddings.values()})
|
||||
self.ids = {name: i for i, name in enumerate(speakers)}
|
||||
|
||||
self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys())))
|
||||
# cache embeddings_by_names for fast inference using a bigger speakers.json
|
||||
self.embeddings_by_names = self.get_embeddings_by_names()
|
||||
|
||||
def get_embedding_by_clip(self, clip_idx: str) -> List:
|
||||
"""Get embedding by clip ID.
|
||||
|
||||
Args:
|
||||
clip_idx (str): Target clip ID.
|
||||
|
||||
Returns:
|
||||
List: embedding as a list.
|
||||
"""
|
||||
return self.embeddings[clip_idx]["embedding"]
|
||||
|
||||
def get_embeddings_by_name(self, idx: str) -> List[List]:
|
||||
"""Get all embeddings of a speaker.
|
||||
|
||||
Args:
|
||||
idx (str): Target name.
|
||||
|
||||
Returns:
|
||||
List[List]: all the embeddings of the given speaker.
|
||||
"""
|
||||
return self.embeddings_by_names[idx]
|
||||
|
||||
def get_embeddings_by_names(self) -> Dict:
|
||||
"""Get all embeddings by names.
|
||||
|
||||
Returns:
|
||||
Dict: all the embeddings of each speaker.
|
||||
"""
|
||||
embeddings_by_names = {}
|
||||
for x in self.embeddings.values():
|
||||
if x["name"] not in embeddings_by_names.keys():
|
||||
embeddings_by_names[x["name"]] = [x["embedding"]]
|
||||
else:
|
||||
embeddings_by_names[x["name"]].append(x["embedding"])
|
||||
return embeddings_by_names
|
||||
|
||||
def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
|
||||
"""Get mean embedding of a idx.
|
||||
|
||||
Args:
|
||||
idx (str): Target name.
|
||||
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
|
||||
randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Mean embedding.
|
||||
"""
|
||||
embeddings = self.get_embeddings_by_name(idx)
|
||||
if num_samples is None:
|
||||
embeddings = np.stack(embeddings).mean(0)
|
||||
else:
|
||||
assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
|
||||
if randomize:
|
||||
embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0)
|
||||
else:
|
||||
embeddings = np.stack(embeddings[:num_samples]).mean(0)
|
||||
return embeddings
|
||||
|
||||
def get_random_speaker_id(self) -> Any:
|
||||
"""Get a random embedding.
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
np.ndarray: embedding.
|
||||
"""
|
||||
if self.ids:
|
||||
return self.ids[random.choices(list(self.ids.keys()))[0]]
|
||||
|
||||
return None
|
||||
|
||||
def get_random_embedding(self) -> Any:
|
||||
"""Get a random embedding.
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
np.ndarray: embedding.
|
||||
"""
|
||||
if self.embeddings:
|
||||
return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"]
|
||||
|
||||
return None
|
||||
|
||||
def get_speakers(self) -> List:
|
||||
return self.ids
|
||||
|
||||
def get_clips(self) -> List:
|
||||
return sorted(self.embeddings.keys())
|
||||
|
||||
def init_encoder(self, model_path: str, config_path: str) -> None:
|
||||
"""Initialize a speaker encoder model.
|
||||
|
||||
Args:
|
||||
model_path (str): Model file path.
|
||||
config_path (str): Model config file path.
|
||||
"""
|
||||
self.encoder_config = load_config(config_path)
|
||||
self.encoder = setup_encoder_model(self.encoder_config)
|
||||
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(
|
||||
self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
|
||||
)
|
||||
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
|
||||
|
||||
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
|
||||
"""Compute a embedding from a given audio file.
|
||||
|
||||
Args:
|
||||
wav_file (Union[str, List[str]]): Target file path.
|
||||
|
||||
Returns:
|
||||
list: Computed embedding.
|
||||
"""
|
||||
|
||||
def _compute(wav_file: str):
|
||||
waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
|
||||
if not self.encoder_config.model_params.get("use_torch_spec", False):
|
||||
m_input = self.encoder_ap.melspectrogram(waveform)
|
||||
m_input = torch.from_numpy(m_input)
|
||||
else:
|
||||
m_input = torch.from_numpy(waveform)
|
||||
|
||||
if self.use_cuda:
|
||||
m_input = m_input.cuda()
|
||||
m_input = m_input.unsqueeze(0)
|
||||
embedding = self.encoder.compute_embedding(m_input)
|
||||
return embedding
|
||||
|
||||
if isinstance(wav_file, list):
|
||||
# compute the mean embedding
|
||||
embeddings = None
|
||||
for wf in wav_file:
|
||||
embedding = _compute(wf)
|
||||
if embeddings is None:
|
||||
embeddings = embedding
|
||||
else:
|
||||
embeddings += embedding
|
||||
return (embeddings / len(wav_file))[0].tolist()
|
||||
embedding = _compute(wav_file)
|
||||
return embedding[0].tolist()
|
||||
|
||||
def compute_embedding(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
|
||||
"""Compute embedding from features.
|
||||
|
||||
Args:
|
||||
feats (Union[torch.Tensor, np.ndarray]): Input features.
|
||||
|
||||
Returns:
|
||||
List: computed embedding.
|
||||
"""
|
||||
if isinstance(feats, np.ndarray):
|
||||
feats = torch.from_numpy(feats)
|
||||
if feats.ndim == 2:
|
||||
feats = feats.unsqueeze(0)
|
||||
if self.use_cuda:
|
||||
feats = feats.cuda()
|
||||
return self.encoder.compute_embedding(feats)
|
||||
|
||||
def run_umap(self):
|
||||
# TODO: implement speaker encoder
|
||||
raise NotImplementedError
|
||||
|
||||
def plot_embeddings(self):
|
||||
# TODO: implement speaker encoder
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
|
||||
"""Initialize a speaker manager from config
|
||||
|
|
Loading…
Reference in New Issue