Add EmbeddingManager and BaseIDManager

This commit is contained in:
Edresson Casanova 2022-03-11 19:01:51 -03:00
parent 40df2cfdd1
commit 4fdc864f74
2 changed files with 291 additions and 252 deletions

View File

@ -0,0 +1,279 @@
import json
import random
from typing import Any, Dict, List, Union
import fsspec
import numpy as np
import torch
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor
class BaseIDManager:
    """Base `ID` Manager class. Every new `ID` manager must inherit this.

    It defines common `ID` manager specific functions: loading/saving the ID mapping
    as JSON, building it from data samples and drawing a random ID.

    Subclasses must implement `parse_ids_from_data()`.

    Args:
        id_file_path (str, optional): Path to a JSON file holding the ID mapping. When it is
            non-empty the mapping is loaded at construction time. Defaults to "".
    """

    def __init__(self, id_file_path: str = ""):
        # Mapping filled from a JSON file or from `parse_ids_from_data()`.
        self.ids = {}

        if id_file_path:
            self.set_ids_from_file(id_file_path)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        """Open `json_file_path` through `fsspec` and return its decoded JSON content."""
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        """Write `data` as JSON (4-space indent) to `json_file_path` through `fsspec`."""
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    def set_ids_from_data(self, items: List) -> None:
        """Set IDs from data samples.

        Args:
            items (List): Data sampled returned by `load_tts_samples()`.
        """
        # `parse_ids_from_data()` returns a pair; only its first element is kept.
        self.ids, _ = self.parse_ids_from_data(items)

    def set_ids_from_file(self, file_path: str) -> None:
        """Set IDs from a json file.

        Args:
            file_path (str): Path to the file.
        """
        self.ids = self._load_json(file_path)

    def save_ids_to_file(self, file_path: str) -> None:
        """Save IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.ids)

    def get_random_speaker_id(self) -> Any:
        """Get a random ID.

        A key of `self.ids` is drawn at random and the value stored under it is returned.

        Returns:
            Any: the value of a randomly drawn entry of `self.ids`, or None when
                `self.ids` is empty.
        """
        if self.ids:
            return self.ids[random.choices(list(self.ids))[0]]
        return None

    @staticmethod
    def parse_ids_from_data(items: list) -> Any:
        """Parse the IDs out of data samples. Must be implemented by subclasses.

        `set_ids_from_data()` unpacks the result into two values and stores the first
        one in `self.ids`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class EmbeddingManager(BaseIDManager):
    """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.

    It defines common `Embedding` manager specific functions: loading/saving embeddings
    as JSON, querying them by clip or by name, and computing new embeddings with an
    encoder model.

    The embedding JSON is read as a mapping `clip name -> entry`, where every entry is
    indexed with the keys `"name"` and `"embedding"`.

    Args:
        embedding_file_path (str, optional): Path to the embedding JSON file. Defaults to "".
        id_file_path (str, optional): Path to the ID JSON file. Defaults to "".
        encoder_model_path (str, optional): Path to the encoder checkpoint. Defaults to "".
        encoder_config_path (str, optional): Path to the encoder config. Defaults to "".
        use_cuda (bool, optional): Move encoder inputs to CUDA before inference. Defaults to False.
    """

    def __init__(
        self,
        embedding_file_path: str = "",
        id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
        use_cuda: bool = False,
    ):
        super().__init__(id_file_path=id_file_path)

        self.embeddings = {}
        self.embeddings_by_names = {}
        self.clip_ids = []
        self.encoder = None
        self.encoder_ap = None
        self.use_cuda = use_cuda

        # NOTE: loading an embedding file rebuilds `self.ids` from the names found in it,
        # replacing whatever was loaded from `id_file_path`.
        if embedding_file_path:
            self.set_embeddings_from_file(embedding_file_path)

        # The encoder is only set up when both the checkpoint and its config are given.
        if encoder_model_path and encoder_config_path:
            self.init_encoder(encoder_model_path, encoder_config_path)

    @property
    def embedding_dim(self):
        """Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
        if self.embeddings:
            # Peek at the first entry without materialising the whole key list.
            return len(next(iter(self.embeddings.values()))["embedding"])
        return 0

    def save_embeddings_to_file(self, file_path: str) -> None:
        """Save embeddings to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.embeddings)

    def set_embeddings_from_file(self, file_path: str) -> None:
        """Load embeddings from a json file.

        Besides `self.embeddings`, this rebuilds `self.ids` (sorted unique names mapped to
        consecutive integers), `self.clip_ids` and the `self.embeddings_by_names` cache.

        Args:
            file_path (str): Path to the target json file.
        """
        self.embeddings = self._load_json(file_path)

        speakers = sorted({x["name"] for x in self.embeddings.values()})
        self.ids = {name: i for i, name in enumerate(speakers)}

        # Dict keys are already unique, so sort them directly. Wrapping the sorted keys
        # in a `set` would throw the ordering away again.
        self.clip_ids = sorted(self.embeddings)
        # cache embeddings_by_names for fast inference using a bigger speakers.json
        self.embeddings_by_names = self.get_embeddings_by_names()

    def get_embedding_by_clip(self, clip_idx: str) -> List:
        """Get embedding by clip ID.

        Args:
            clip_idx (str): Target clip ID.

        Returns:
            List: embedding as a list.
        """
        return self.embeddings[clip_idx]["embedding"]

    def get_embeddings_by_name(self, idx: str) -> List[List]:
        """Get all embeddings of a name, read from the `self.embeddings_by_names` cache.

        Args:
            idx (str): Target name.

        Returns:
            List[List]: all the embeddings of the given name.
        """
        return self.embeddings_by_names[idx]

    def get_embeddings_by_names(self) -> Dict:
        """Get all embeddings by names.

        Returns:
            Dict: all the embeddings of each name, in the iteration order of `self.embeddings`.
        """
        embeddings_by_names = {}
        for x in self.embeddings.values():
            embeddings_by_names.setdefault(x["name"], []).append(x["embedding"])
        return embeddings_by_names

    def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
        """Get mean embedding of a idx.

        Args:
            idx (str): Target name.
            num_samples (int, optional): Number of samples to be averaged. Defaults to None.
            randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.

        Returns:
            np.ndarray: Mean embedding.
        """
        embeddings = self.get_embeddings_by_name(idx)
        if num_samples is None:
            return np.stack(embeddings).mean(0)

        assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
        if randomize:
            # NOTE: `random.choices` draws with replacement, so an embedding may be picked twice.
            selected = random.choices(embeddings, k=num_samples)
        else:
            selected = embeddings[:num_samples]
        return np.stack(selected).mean(0)

    def get_random_embedding(self) -> Any:
        """Get a random embedding.

        Returns:
            Any: the `"embedding"` field of a randomly drawn entry, or None when no
                embeddings are loaded.
        """
        if self.embeddings:
            return self.embeddings[random.choices(list(self.embeddings))[0]]["embedding"]
        return None

    def get_clips(self) -> List:
        """Return the sorted clip names (keys of `self.embeddings`)."""
        return sorted(self.embeddings)

    def init_encoder(self, model_path: str, config_path: str) -> None:
        """Initialize an encoder model.

        Sets `self.encoder_config`, `self.encoder`, `self.encoder_criterion` and
        `self.encoder_ap`.

        Args:
            model_path (str): Model file path.
            config_path (str): Model config file path.
        """
        self.encoder_config = load_config(config_path)
        self.encoder = setup_encoder_model(self.encoder_config)
        self.encoder_criterion = self.encoder.load_checkpoint(
            self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
        )
        self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

    def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
        """Compute a embedding from a given audio file.

        Requires `init_encoder()` to have been called.

        Args:
            wav_file (Union[str, List[str]]): Target file path, or a list of paths whose
                embeddings are averaged.

        Returns:
            list: Computed embedding.

        Raises:
            ValueError: if `wav_file` is an empty list.
        """

        def _compute(wav_file: str):
            waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
            # The encoder receives a mel spectrogram unless the config sets `use_torch_spec`,
            # in which case the raw waveform is passed.
            if not self.encoder_config.model_params.get("use_torch_spec", False):
                m_input = self.encoder_ap.melspectrogram(waveform)
                m_input = torch.from_numpy(m_input)
            else:
                m_input = torch.from_numpy(waveform)

            if self.use_cuda:
                m_input = m_input.cuda()
            m_input = m_input.unsqueeze(0)
            embedding = self.encoder.compute_embedding(m_input)
            return embedding

        if isinstance(wav_file, list):
            if not wav_file:
                # Without this guard the averaging below ends in an opaque `None / 0` TypeError.
                raise ValueError(" [!] `wav_file` is an empty list; at least one audio file is required.")
            # compute the mean embedding
            embeddings = None
            for wf in wav_file:
                embedding = _compute(wf)
                if embeddings is None:
                    embeddings = embedding
                else:
                    embeddings += embedding
            return (embeddings / len(wav_file))[0].tolist()
        embedding = _compute(wav_file)
        return embedding[0].tolist()

    def compute_embedding(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
        """Compute embedding from features.

        Args:
            feats (Union[torch.Tensor, np.ndarray]): Input features. A 2-dimensional input
                gets a leading dimension added before it is passed to the encoder.

        Returns:
            List: computed embedding. NOTE(review): the value is returned exactly as
                `self.encoder.compute_embedding()` produces it - confirm its actual type.
        """
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats)
        if feats.ndim == 2:
            feats = feats.unsqueeze(0)
        if self.use_cuda:
            feats = feats.cuda()
        return self.encoder.compute_embedding(feats)

    @staticmethod
    def parse_ids_from_data(items: list) -> Any:
        """Parse the IDs out of data samples. Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this class.
        """
        raise NotImplementedError

View File

@ -1,6 +1,5 @@
import json
import os
import random
from typing import Any, Dict, List, Tuple, Union
import fsspec
@ -8,12 +7,10 @@ import numpy as np
import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
class SpeakerManager:
class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by speaker or clip.
@ -62,36 +59,16 @@ class SpeakerManager:
encoder_config_path: str = "",
use_cuda: bool = False,
):
self.embeddings = {}
self.ids = {}
self.embeddings_by_names = {}
self.clip_ids = []
self.encoder = None
self.encoder_ap = None
self.use_cuda = use_cuda
super().__init__(
embedding_file_path=d_vectors_file_path,
id_file_path=speaker_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
use_cuda=use_cuda
)
if data_items:
self.ids, _ = self.parse_from_data(data_items)
if d_vectors_file_path:
self.set_embeddings_from_file(d_vectors_file_path)
if speaker_id_file_path:
self.set_ids_from_file(speaker_id_file_path)
if encoder_model_path and encoder_config_path:
self.init_encoder(encoder_model_path, encoder_config_path)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
    """Open `json_file_path` for reading through `fsspec` and return the decoded JSON."""
    with fsspec.open(json_file_path, "r") as handle:
        content = json.load(handle)
    return content
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
    """Serialise `data` as JSON (4-space indent) into `json_file_path` through `fsspec`."""
    with fsspec.open(json_file_path, "w") as handle:
        json.dump(data, handle, indent=4)
self.ids, _ = self.parse_ids_from_data(data_items)
@property
def num_speakers(self):
@ -101,15 +78,8 @@ class SpeakerManager:
def speaker_names(self):
    """Return the speaker names (the keys of `self.ids`) as a new list."""
    names = list(self.ids)
    return names
@property
def embedding_dim(self):
    """Length of the first stored embedding; zero when no embeddings are loaded."""
    if not self.embeddings:
        return 0
    first_entry = next(iter(self.embeddings.values()))
    return len(first_entry["embedding"])
@staticmethod
def parse_from_data(items: list) -> Tuple[Dict, int]:
def parse_ids_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples retured by `load_tts_samples()`.
Args:
@ -123,219 +93,9 @@ class SpeakerManager:
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def set_ids_from_data(self, items: List) -> None:
    """Rebuild `self.ids` from data samples.

    Args:
        items (List): Data sampled returned by `load_tts_samples()`.
    """
    # `parse_from_data()` yields a pair; only its first element is stored.
    parsed_ids, _ = self.parse_from_data(items)
    self.ids = parsed_ids
def set_ids_from_file(self, file_path: str) -> None:
    """Replace `self.ids` with the JSON content of a file.

    Args:
        file_path (str): Path to the file.
    """
    loaded = self._load_json(file_path)
    self.ids = loaded
def save_ids_to_file(self, file_path: str) -> None:
    """Write `self.ids` to a json file.

    Args:
        file_path (str): Path to the output file.
    """
    ids_to_store = self.ids
    self._save_json(file_path, ids_to_store)
def save_embeddings_to_file(self, file_path: str) -> None:
    """Write `self.embeddings` to a json file.

    Args:
        file_path (str): Path to the output file.
    """
    payload = self.embeddings
    self._save_json(file_path, payload)
def set_embeddings_from_file(self, file_path: str) -> None:
    """Load embeddings from a json file.

    Besides `self.embeddings`, this rebuilds `self.ids` (sorted unique names mapped to
    consecutive integers), `self.clip_ids` and the `self.embeddings_by_names` cache.

    Args:
        file_path (str): Path to the target json file.
    """
    self.embeddings = self._load_json(file_path)

    speakers = sorted({x["name"] for x in self.embeddings.values()})
    self.ids = {name: i for i, name in enumerate(speakers)}

    # Dict keys are already unique, so sort them directly. Wrapping the sorted keys
    # in a `set` would throw the ordering away again.
    self.clip_ids = sorted(self.embeddings)
    # cache embeddings_by_names for fast inference using a bigger speakers.json
    self.embeddings_by_names = self.get_embeddings_by_names()
def get_embedding_by_clip(self, clip_idx: str) -> List:
    """Look up the embedding stored for one clip.

    Args:
        clip_idx (str): Target clip ID.

    Returns:
        List: embedding as a list.
    """
    clip_entry = self.embeddings[clip_idx]
    return clip_entry["embedding"]
def get_embeddings_by_name(self, idx: str) -> List[List]:
    """Return every embedding cached for one name in `self.embeddings_by_names`.

    Args:
        idx (str): Target name.

    Returns:
        List[List]: all the embeddings of the given speaker.
    """
    cached = self.embeddings_by_names
    return cached[idx]
def get_embeddings_by_names(self) -> Dict:
    """Group the stored embeddings by the `"name"` field of their entry.

    Returns:
        Dict: all the embeddings of each speaker.
    """
    grouped = {}
    for entry in self.embeddings.values():
        grouped.setdefault(entry["name"], []).append(entry["embedding"])
    return grouped
def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
    """Average the embeddings stored for one name.

    Args:
        idx (str): Target name.
        num_samples (int, optional): Number of samples to be averaged. Defaults to None.
        randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.

    Returns:
        np.ndarray: Mean embedding.
    """
    pool = self.get_embeddings_by_name(idx)
    if num_samples is None:
        # No limit requested: average everything available.
        return np.stack(pool).mean(0)

    assert len(pool) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
    if randomize:
        # NOTE: `random.choices` draws with replacement.
        chosen = random.choices(pool, k=num_samples)
    else:
        chosen = pool[:num_samples]
    return np.stack(chosen).mean(0)
def get_random_speaker_id(self) -> Any:
    """Get a random speaker ID.

    A key of `self.ids` is drawn at random and the value stored under it is returned.

    Returns:
        Any: the value of a randomly drawn entry of `self.ids`, or None when
            `self.ids` is empty.
    """
    if not self.ids:
        return None
    picked_name = random.choices(list(self.ids.keys()))[0]
    return self.ids[picked_name]
def get_random_embedding(self) -> Any:
    """Get a random embedding.

    Returns:
        Any: the `"embedding"` field of a randomly drawn entry of `self.embeddings`,
            or None when no embeddings are loaded.
    """
    if not self.embeddings:
        return None
    picked_clip = random.choices(list(self.embeddings.keys()))[0]
    return self.embeddings[picked_clip]["embedding"]
def get_speakers(self) -> List:
    """Return `self.ids` itself (the same object, not a copy)."""
    speakers = self.ids
    return speakers
def get_clips(self) -> List:
    """Return the clip names (keys of `self.embeddings`) in sorted order."""
    clip_names = sorted(self.embeddings)
    return clip_names
def init_encoder(self, model_path: str, config_path: str) -> None:
    """Initialize a speaker encoder model.

    Sets `self.encoder_config`, `self.encoder`, `self.speaker_encoder_criterion` and
    `self.encoder_ap`.

    Args:
        model_path (str): Model file path.
        config_path (str): Model config file path.
    """
    self.encoder_config = load_config(config_path)
    self.encoder = setup_encoder_model(self.encoder_config)
    # The model was just stored on `self.encoder`, so the checkpoint must be loaded
    # through that attribute. `self.speaker_encoder` is never assigned here.
    self.speaker_encoder_criterion = self.encoder.load_checkpoint(
        self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
    )
    self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
    """Compute a embedding from a given audio file.

    Requires `init_encoder()` to have been called.

    Args:
        wav_file (Union[str, List[str]]): Target file path, or a list of paths whose
            embeddings are averaged.

    Returns:
        list: Computed embedding.

    Raises:
        ValueError: if `wav_file` is an empty list.
    """

    def _compute(wav_file: str):
        waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
        # The encoder receives a mel spectrogram unless the config sets `use_torch_spec`,
        # in which case the raw waveform is passed.
        if not self.encoder_config.model_params.get("use_torch_spec", False):
            m_input = self.encoder_ap.melspectrogram(waveform)
            m_input = torch.from_numpy(m_input)
        else:
            m_input = torch.from_numpy(waveform)

        if self.use_cuda:
            m_input = m_input.cuda()
        m_input = m_input.unsqueeze(0)
        embedding = self.encoder.compute_embedding(m_input)
        return embedding

    if isinstance(wav_file, list):
        if not wav_file:
            # Without this guard the averaging below ends in an opaque `None / 0` TypeError.
            raise ValueError(" [!] `wav_file` is an empty list; at least one audio file is required.")
        # compute the mean embedding
        embeddings = None
        for wf in wav_file:
            embedding = _compute(wf)
            if embeddings is None:
                embeddings = embedding
            else:
                embeddings += embedding
        return (embeddings / len(wav_file))[0].tolist()
    embedding = _compute(wav_file)
    return embedding[0].tolist()
def compute_embedding(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
    """Run the encoder on precomputed features.

    Args:
        feats (Union[torch.Tensor, np.ndarray]): Input features. A 2-dimensional input
            gets a leading dimension added before it is passed to the encoder.

    Returns:
        List: computed embedding. NOTE(review): the value is returned exactly as
            `self.encoder.compute_embedding()` produces it - confirm its actual type.
    """
    tensor = torch.from_numpy(feats) if isinstance(feats, np.ndarray) else feats
    if tensor.ndim == 2:
        tensor = tensor.unsqueeze(0)
    if self.use_cuda:
        tensor = tensor.cuda()
    return self.encoder.compute_embedding(tensor)
def run_umap(self):
    """Placeholder that is not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    # TODO: implement speaker encoder
    raise NotImplementedError
def plot_embeddings(self):
    """Placeholder that is not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    # TODO: implement speaker encoder
    raise NotImplementedError
@staticmethod
def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
"""Initialize a speaker manager from config