Add EmbeddingManager and BaseIDManager (#1374)

Edresson Casanova 2022-03-31 08:41:16 -03:00 committed by GitHub
parent 1b22f03e98
commit 060e0f9368
27 changed files with 412 additions and 404 deletions

View File

@ -49,7 +49,7 @@ encoder_manager = SpeakerManager(
use_cuda=args.use_cuda,
)
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
class_name_key = encoder_manager.encoder_config.class_name_key
# compute speaker embeddings
speaker_mapping = {}
@ -63,10 +63,10 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
wav_file_name = os.path.basename(wav_file)
if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
# get the embedding from the old file
embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
embedd = encoder_manager.get_embedding_by_clip(wav_file_name)
else:
# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
embedd = encoder_manager.compute_embedding_from_clip(wav_file)
# create speaker_mapping if target dataset is defined
speaker_mapping[wav_file_name] = {}
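Taken together, these hunks are a pure rename on the manager API. A minimal sketch of the new call sites, assuming hypothetical checkpoint and wav paths and a previously computed d-vector file:

from TTS.tts.utils.speakers import SpeakerManager

# hypothetical paths; "old_speakers.json" stands in for a previously computed d-vector file
manager = SpeakerManager(
    d_vectors_file_path="old_speakers.json",
    encoder_model_path="se_checkpoint.pth.tar",
    encoder_config_path="se_config.json",
    use_cuda=False,
)
wav_name, wav_path = "clip_0001.wav", "/data/wavs/clip_0001.wav"
if wav_name in manager.clip_ids:
    emb = manager.get_embedding_by_clip(wav_name)        # was get_d_vector_by_clip
else:
    emb = manager.compute_embedding_from_clip(wav_path)  # was compute_d_vector_from_clip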

View File

@ -11,8 +11,8 @@ from TTS.tts.utils.speakers import SpeakerManager
def compute_encoder_accuracy(dataset_items, encoder_manager):
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
class_name_key = encoder_manager.encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
class_acc_dict = {}
@ -22,13 +22,13 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
wav_file = item["audio_file"]
# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
embedd = encoder_manager.compute_embedding_from_clip(wav_file)
if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
embedding = torch.FloatTensor(embedd).unsqueeze(0)
if encoder_manager.use_cuda:
embedding = embedding.cuda()
class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
predicted_label = map_classid_to_classname[str(class_id)]
else:
predicted_label = None

View File

@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
precompute_num_workers=0,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
speaker_id_mapping=speaker_manager.ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
)
if c.use_phonemes and c.compute_input_seq_cache:

View File

@ -278,7 +278,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(
" > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
print(synthesizer.tts_model.speaker_manager.speaker_ids)
print(synthesizer.tts_model.speaker_manager.ids)
return
# query language ids of a multi-lingual model.
@ -286,7 +286,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.language_id_mapping)
print(synthesizer.tts_model.language_manager.ids)
return
# check the arguments against a multi-speaker model.
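After the rename, both managers expose their name-to-index mapping as .ids. A sketch of the listing path, assuming hypothetical checkpoint paths for a multi-speaker, multi-lingual model (the Synthesizer arguments shown are the usual ones, not taken from this diff):

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="model.pth.tar",  # hypothetical
    tts_config_path="config.json",   # hypothetical
)
print(list(synthesizer.tts_model.speaker_manager.ids.keys()))   # speaker names
print(list(synthesizer.tts_model.language_manager.ids.keys()))  # language ids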

View File

@ -12,7 +12,7 @@ from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
@ -258,7 +258,7 @@ def main(args): # pylint: disable=redefined-outer-name
global train_classes
ap = AudioProcessor(**c.audio)
model = setup_speaker_encoder_model(c)
model = setup_encoder_model(c)
optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)

View File

@ -125,7 +125,7 @@ def to_camel(text):
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_speaker_encoder_model(config: "Coqpit"):
def setup_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
config.model_params["input_dim"],

View File

@ -143,7 +143,7 @@ def index():
"index.html",
show_details=args.show_details,
use_multi_speaker=use_multi_speaker,
speaker_ids=speaker_manager.speaker_ids if speaker_manager is not None else None,
speaker_ids=speaker_manager.ids if speaker_manager is not None else None,
use_gst=use_gst,
)

View File

@ -136,18 +136,18 @@ class BaseTTS(BaseTrainerModel):
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.speaker_manager.ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]
return {
"text": text,
@ -279,23 +279,19 @@ class BaseTTS(BaseTrainerModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None
# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
else:
language_id_mapping = None
@ -352,13 +348,13 @@ class BaseTTS(BaseTrainerModel):
d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
d_vector = (random.sample(sorted(d_vector), 1),)
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
else random.sample(sorted(self.speaker_manager.ids.values()), 1),
"d_vector": d_vector,
"style_wav": None, # TODO: handle GST style input
}
@ -405,7 +401,7 @@ class BaseTTS(BaseTrainerModel):
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
self.speaker_manager.save_ids_to_file(output_path)
trainer.config.speakers_file = output_path
# some models don't have `model_args` set
if hasattr(trainer.config, "model_args"):
@ -416,7 +412,7 @@ class BaseTTS(BaseTrainerModel):
if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
self.language_manager.save_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
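Both managers now share a single save method inherited from BaseIDManager. A minimal sketch of the renamed hooks, with hypothetical samples and output paths:

from TTS.tts.utils.speakers import SpeakerManager

samples = [{"speaker_name": "spk_a"}, {"speaker_name": "spk_b"}]  # stand-in for load_tts_samples() output
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(samples, parse_key="speaker_name")
speaker_manager.save_ids_to_file("speakers.json")  # was save_speaker_ids_to_file; LanguageManager gains the same method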

View File

@ -124,7 +124,7 @@ class GlowTTS(BaseTTS):
)
if self.speaker_manager is not None:
assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
config.d_vector_dim == self.speaker_manager.embedding_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:

View File

@ -652,28 +652,28 @@ class Vits(BaseTTS):
# TODO: make this a function
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.speaker_encoder is None and (
if self.speaker_manager.encoder is None and (
not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
)
self.speaker_manager.speaker_encoder.eval()
self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")
if (
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
hasattr(self.speaker_manager.encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
):
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
def _init_speaker_embedding(self):
@ -887,7 +887,7 @@ class Vits(BaseTTS):
pad_short=True,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
# concatenate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)
@ -896,7 +896,7 @@ class Vits(BaseTTS):
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
@ -1223,18 +1223,18 @@ class Vits(BaseTTS):
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.speaker_manager.ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]
return {
"text": text,
@ -1289,26 +1289,22 @@ class Vits(BaseTTS):
d_vectors = None
# get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]
if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
batch["speaker_ids"] = speaker_ids
# get d_vectors from audio file names
if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.d_vectors
if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.embeddings
d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
d_vectors = torch.FloatTensor(d_vectors)
# get language ids from language names
if (
self.language_manager is not None
and self.language_manager.language_id_mapping
and self.args.use_language_embedding
):
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
@ -1490,7 +1486,7 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config)
if config.model_args.speaker_encoder_model_path:
speaker_manager.init_speaker_encoder(
speaker_manager.init_encoder(
config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
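For the speaker consistency loss, the encoder is now reached through the renamed members. A sketch under hypothetical paths:

from TTS.tts.utils.speakers import SpeakerManager

speaker_manager = SpeakerManager()
speaker_manager.init_encoder("se_checkpoint.pth.tar", "se_config.json")  # was init_speaker_encoder
speaker_manager.encoder.eval()  # was speaker_manager.speaker_encoder
# the model guards this access with hasattr() before deciding whether to resample
encoder_sr = speaker_manager.encoder.audio_config["sample_rate"]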

View File

@ -1,6 +1,5 @@
import json
import os
from typing import Dict, List
from typing import Any, Dict, List
import fsspec
import numpy as np
@ -8,9 +7,10 @@ import torch
from coqpit import Coqpit
from TTS.config import check_config_and_model_args
from TTS.tts.utils.managers import BaseIDManager
class LanguageManager:
class LanguageManager(BaseIDManager):
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.
@ -25,37 +25,23 @@ class LanguageManager:
>>> language_id_mapper = manager.language_ids
"""
language_id_mapping: Dict = {}
def __init__(
self,
language_ids_file_path: str = "",
config: Coqpit = None,
):
self.language_id_mapping = {}
if language_ids_file_path:
self.set_language_ids_from_file(language_ids_file_path)
super().__init__(id_file_path=language_ids_file_path)
if config:
self.set_language_ids_from_config(config)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
def num_languages(self) -> int:
return len(list(self.language_id_mapping.keys()))
return len(list(self.ids.keys()))
@property
def language_names(self) -> List:
return list(self.language_id_mapping.keys())
return list(self.ids.keys())
@staticmethod
def parse_language_ids_from_config(c: Coqpit) -> Dict:
@ -79,25 +65,24 @@ class LanguageManager:
"""Set language IDs from config samples.
Args:
items (List): Data sampled returned by `load_meta_data()`.
c (Coqpit): Config.
"""
self.language_id_mapping = self.parse_language_ids_from_config(c)
self.ids = self.parse_language_ids_from_config(c)
def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file.
@staticmethod
def parse_ids_from_data(items: List, parse_key: str) -> Any:
raise NotImplementedError
Args:
file_path (str): Path to the target json file.
"""
self.language_id_mapping = self._load_json(file_path)
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
raise NotImplementedError
def save_language_ids_to_file(self, file_path: str) -> None:
def save_ids_to_file(self, file_path: str) -> None:
"""Save language IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.language_id_mapping)
self._save_json(file_path, self.ids)
@staticmethod
def init_from_config(config: Coqpit) -> "LanguageManager":
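With the base class in place, LanguageManager keeps only the language-specific parsing and delegates the ID bookkeeping. A sketch of the shared surface, assuming a hypothetical language_ids.json such as {"en": 0, "pt-br": 1}:

from TTS.tts.utils.languages import LanguageManager

manager = LanguageManager(language_ids_file_path="language_ids.json")  # hypothetical file
print(manager.ids)            # was manager.language_id_mapping
print(manager.num_languages)  # len(manager.ids)
manager.save_ids_to_file("language_ids_out.json")  # was save_language_ids_to_file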

TTS/tts/utils/managers.py (new file, 285 lines)
View File

@ -0,0 +1,285 @@
import json
import random
from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np
import torch
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor
class BaseIDManager:
"""Base `ID` Manager class. Every new `ID` manager must inherit this.
It defines common `ID` manager specific functions.
"""
def __init__(self, id_file_path: str = ""):
self.ids = {}
if id_file_path:
self.load_ids_from_file(id_file_path)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
def set_ids_from_data(self, items: List, parse_key: str) -> None:
"""Set IDs from data samples.
Args:
items (List): Data samples returned by `load_tts_samples()`.
parse_key (str): Key by which the samples are grouped.
"""
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
def load_ids_from_file(self, file_path: str) -> None:
"""Set IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.ids = self._load_json(file_path)
def save_ids_to_file(self, file_path: str) -> None:
"""Save IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.ids)
def get_random_id(self) -> Any:
"""Get a random embedding.
Args:
Returns:
np.ndarray: embedding.
"""
if self.ids:
return self.ids[random.choices(list(self.ids.keys()))[0]]
return None
@staticmethod
def parse_ids_from_data(items: List, parse_key: str) -> Dict:
"""Parse IDs from data samples returned by `load_tts_samples()`.
Args:
items (list): Data samples returned by `load_tts_samples()`.
parse_key (str): The key used to parse the data.
Returns:
Dict: a mapping from parsed names to integer IDs.
"""
classes = sorted({item[parse_key] for item in items})
ids = {name: i for i, name in enumerate(classes)}
return ids
class EmbeddingManager(BaseIDManager):
"""Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
It defines common `Embedding` manager specific functions.
"""
def __init__(
self,
embedding_file_path: str = "",
id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
use_cuda: bool = False,
):
super().__init__(id_file_path=id_file_path)
self.embeddings = {}
self.embeddings_by_names = {}
self.clip_ids = []
self.encoder = None
self.encoder_ap = None
self.use_cuda = use_cuda
if embedding_file_path:
self.load_embeddings_from_file(embedding_file_path)
if encoder_model_path and encoder_config_path:
self.init_encoder(encoder_model_path, encoder_config_path)
@property
def embedding_dim(self):
"""Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
if self.embeddings:
return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"])
return 0
def save_embeddings_to_file(self, file_path: str) -> None:
"""Save embeddings to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.embeddings)
def load_embeddings_from_file(self, file_path: str) -> None:
"""Load embeddings from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.embeddings = self._load_json(file_path)
speakers = sorted({x["name"] for x in self.embeddings.values()})
self.ids = {name: i for i, name in enumerate(speakers)}
self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys())))
# cache embeddings_by_names for fast inference using a bigger speakers.json
self.embeddings_by_names = self.get_embeddings_by_names()
def get_embedding_by_clip(self, clip_idx: str) -> List:
"""Get embedding by clip ID.
Args:
clip_idx (str): Target clip ID.
Returns:
List: embedding as a list.
"""
return self.embeddings[clip_idx]["embedding"]
def get_embeddings_by_name(self, idx: str) -> List[List]:
"""Get all embeddings of a speaker.
Args:
idx (str): Target name.
Returns:
List[List]: all the embeddings of the given speaker.
"""
return self.embeddings_by_names[idx]
def get_embeddings_by_names(self) -> Dict:
"""Get all embeddings by names.
Returns:
Dict: all the embeddings of each speaker.
"""
embeddings_by_names = {}
for x in self.embeddings.values():
if x["name"] not in embeddings_by_names.keys():
embeddings_by_names[x["name"]] = [x["embedding"]]
else:
embeddings_by_names[x["name"]].append(x["embedding"])
return embeddings_by_names
def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
"""Get mean embedding of a idx.
Args:
idx (str): Target name.
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.
Returns:
np.ndarray: Mean embedding.
"""
embeddings = self.get_embeddings_by_name(idx)
if num_samples is None:
embeddings = np.stack(embeddings).mean(0)
else:
assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
if randomize:
embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0)
else:
embeddings = np.stack(embeddings[:num_samples]).mean(0)
return embeddings
def get_random_embedding(self) -> Any:
"""Get a random embedding.
Returns:
Any: a random embedding, or None if no embeddings are loaded.
"""
if self.embeddings:
return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"]
return None
def get_clips(self) -> List:
return sorted(self.embeddings.keys())
def init_encoder(self, model_path: str, config_path: str) -> None:
"""Initialize a speaker encoder model.
Args:
model_path (str): Model file path.
config_path (str): Model config file path.
"""
self.encoder_config = load_config(config_path)
self.encoder = setup_encoder_model(self.encoder_config)
self.encoder_criterion = self.encoder.load_checkpoint(
self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
)
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
"""Compute a embedding from a given audio file.
Args:
wav_file (Union[str, List[str]]): Target file path.
Returns:
list: Computed embedding.
"""
def _compute(wav_file: str):
waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
if not self.encoder_config.model_params.get("use_torch_spec", False):
m_input = self.encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda:
m_input = m_input.cuda()
m_input = m_input.unsqueeze(0)
embedding = self.encoder.compute_embedding(m_input)
return embedding
if isinstance(wav_file, list):
# compute the mean embedding
embeddings = None
for wf in wav_file:
embedding = _compute(wf)
if embeddings is None:
embeddings = embedding
else:
embeddings += embedding
return (embeddings / len(wav_file))[0].tolist()
embedding = _compute(wav_file)
return embedding[0].tolist()
def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
"""Compute embedding from features.
Args:
feats (Union[torch.Tensor, np.ndarray]): Input features.
Returns:
List: computed embedding.
"""
if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats)
if feats.ndim == 2:
feats = feats.unsqueeze(0)
if self.use_cuda:
feats = feats.cuda()
return self.encoder.compute_embedding(feats)
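Putting the new class together, a sketch assuming a hypothetical speakers.json keyed by clip name, e.g. {"clip_0001.wav": {"name": "spk_a", "embedding": [...]}}:

from TTS.tts.utils.managers import EmbeddingManager

manager = EmbeddingManager(embedding_file_path="speakers.json")  # hypothetical file
print(manager.embedding_dim)                              # length of one stored embedding
emb = manager.get_embedding_by_clip(manager.clip_ids[0])  # one clip's embedding
# average over two random clips of "spk_a" (assumes the speaker has at least two clips)
mean = manager.get_mean_embedding("spk_a", num_samples=2, randomize=True)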

View File

@ -1,19 +1,17 @@
import json
import os
import random
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
class SpeakerManager:
class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by speaker or clip.
@ -50,7 +48,7 @@ class SpeakerManager:
>>> # load a sample audio and compute embedding
>>> waveform = ap.load_wav(sample_wav_path)
>>> mel = ap.melspectrogram(waveform)
>>> d_vector = manager.compute_d_vector(mel.T)
>>> d_vector = manager.compute_embeddings(mel.T)
"""
def __init__(
@ -62,279 +60,27 @@ class SpeakerManager:
encoder_config_path: str = "",
use_cuda: bool = False,
):
self.d_vectors = {}
self.speaker_ids = {}
self.d_vectors_by_speakers = {}
self.clip_ids = []
self.speaker_encoder = None
self.speaker_encoder_ap = None
self.use_cuda = use_cuda
super().__init__(
embedding_file_path=d_vectors_file_path,
id_file_path=speaker_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
use_cuda=use_cuda,
)
if data_items:
self.speaker_ids, _ = self.parse_speakers_from_data(data_items)
if d_vectors_file_path:
self.set_d_vectors_from_file(d_vectors_file_path)
if speaker_id_file_path:
self.set_speaker_ids_from_file(speaker_id_file_path)
if encoder_model_path and encoder_config_path:
self.init_speaker_encoder(encoder_model_path, encoder_config_path)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
self.set_ids_from_data(data_items, parse_key="speaker_name")
@property
def num_speakers(self):
return len(self.speaker_ids)
return len(self.ids)
@property
def speaker_names(self):
return list(self.speaker_ids.keys())
@property
def d_vector_dim(self):
"""Dimensionality of d_vectors. If d_vectors are not loaded, returns zero."""
if self.d_vectors:
return len(self.d_vectors[list(self.d_vectors.keys())[0]]["embedding"])
return 0
@staticmethod
def parse_speakers_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples retured by `load_tts_samples()`.
Args:
items (list): Data sampled returned by `load_tts_samples()`.
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item["speaker_name"] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def set_speaker_ids_from_data(self, items: List) -> None:
"""Set speaker IDs from data samples.
Args:
items (List): Data sampled returned by `load_tts_samples()`.
"""
self.speaker_ids, _ = self.parse_speakers_from_data(items)
def set_speaker_ids_from_file(self, file_path: str) -> None:
"""Set speaker IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.speaker_ids = self._load_json(file_path)
def save_speaker_ids_to_file(self, file_path: str) -> None:
"""Save speaker IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.speaker_ids)
def save_d_vectors_to_file(self, file_path: str) -> None:
"""Save d_vectors to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.d_vectors)
def set_d_vectors_from_file(self, file_path: str) -> None:
"""Load d_vectors from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.d_vectors = self._load_json(file_path)
speakers = sorted({x["name"] for x in self.d_vectors.values()})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
# cache d_vectors_by_speakers for fast inference using a bigger speakers.json
self.d_vectors_by_speakers = self.get_d_vectors_by_speakers()
def get_d_vector_by_clip(self, clip_idx: str) -> List:
"""Get d_vector by clip ID.
Args:
clip_idx (str): Target clip ID.
Returns:
List: d_vector as a list.
"""
return self.d_vectors[clip_idx]["embedding"]
def get_d_vectors_by_speaker(self, speaker_idx: str) -> List[List]:
"""Get all d_vectors of a speaker.
Args:
speaker_idx (str): Target speaker ID.
Returns:
List[List]: all the d_vectors of the given speaker.
"""
return self.d_vectors_by_speakers[speaker_idx]
def get_d_vectors_by_speakers(self) -> Dict:
"""Get all d_vectors by speaker.
Returns:
Dict: all the d_vectors of each speaker.
"""
d_vectors_by_speakers = {}
for x in self.d_vectors.values():
if x["name"] not in d_vectors_by_speakers.keys():
d_vectors_by_speakers[x["name"]] = [x["embedding"]]
else:
d_vectors_by_speakers[x["name"]].append(x["embedding"])
return d_vectors_by_speakers
def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
"""Get mean d_vector of a speaker ID.
Args:
speaker_idx (str): Target speaker ID.
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
randomize (bool, optional): Pick random `num_samples` of d_vectors. Defaults to False.
Returns:
np.ndarray: Mean d_vector.
"""
d_vectors = self.get_d_vectors_by_speaker(speaker_idx)
if num_samples is None:
d_vectors = np.stack(d_vectors).mean(0)
else:
assert len(d_vectors) >= num_samples, f" [!] speaker {speaker_idx} has number of samples < {num_samples}"
if randomize:
d_vectors = np.stack(random.choices(d_vectors, k=num_samples)).mean(0)
else:
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
return d_vectors
def get_random_speaker_id(self) -> Any:
"""Get a random d_vector.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.speaker_ids:
return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
return None
def get_random_d_vector(self) -> Any:
"""Get a random D ID.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.d_vectors:
return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
return None
return list(self.ids.keys())
def get_speakers(self) -> List:
return self.speaker_ids
def get_clips(self) -> List:
return sorted(self.d_vectors.keys())
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
"""Initialize a speaker encoder model.
Args:
model_path (str): Model file path.
config_path (str): Model config file path.
"""
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(
self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda
)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
"""Compute a d_vector from a given audio file.
Args:
wav_file (Union[str, List[str]]): Target file path.
Returns:
list: Computed d_vector.
"""
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda:
m_input = m_input.cuda()
m_input = m_input.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(m_input)
return d_vector
if isinstance(wav_file, list):
# compute the mean d_vector
d_vectors = None
for wf in wav_file:
d_vector = _compute(wf)
if d_vectors is None:
d_vectors = d_vector
else:
d_vectors += d_vector
return (d_vectors / len(wav_file))[0].tolist()
d_vector = _compute(wav_file)
return d_vector[0].tolist()
def compute_d_vector(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
"""Compute d_vector from features.
Args:
feats (Union[torch.Tensor, np.ndarray]): Input features.
Returns:
List: computed d_vector.
"""
if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats)
if feats.ndim == 2:
feats = feats.unsqueeze(0)
if self.use_cuda:
feats = feats.cuda()
return self.speaker_encoder.compute_embedding(feats)
def run_umap(self):
# TODO: implement speaker encoder
raise NotImplementedError
def plot_embeddings(self):
# TODO: implement speaker encoder
raise NotImplementedError
return self.ids
@staticmethod
def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
@ -420,7 +166,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager = SpeakerManager()
if c.use_speaker_embedding:
if data is not None:
speaker_manager.set_speaker_ids_from_data(data)
speaker_manager.set_ids_from_data(data, parse_key="speaker_name")
if restore_path:
speakers_file = _set_file_path(restore_path)
# restoring speaker manager from a previous run.
@ -432,27 +178,27 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
)
speaker_manager.load_d_vectors_file(c.d_vector_file)
speaker_manager.set_d_vectors_from_file(speakers_file)
speaker_manager.load_embeddings_from_file(c.d_vector_file)
speaker_manager.load_embeddings_from_file(speakers_file)
elif not c.use_d_vector_file: # restore speaker manager with speaker ID file.
speaker_ids_from_data = speaker_manager.speaker_ids
speaker_manager.set_speaker_ids_from_file(speakers_file)
speaker_ids_from_data = speaker_manager.ids
speaker_manager.load_ids_from_file(speakers_file)
assert all(
speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data
speaker in speaker_manager.ids for speaker in speaker_ids_from_data
), " [!] You cannot introduce new speakers to a pre-trained model."
elif c.use_d_vector_file and c.d_vector_file:
# new speaker manager with external speaker embeddings.
speaker_manager.set_d_vectors_from_file(c.d_vector_file)
speaker_manager.load_embeddings_from_file(c.d_vector_file)
elif c.use_d_vector_file and not c.d_vector_file:
raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
speaker_manager.load_ids_from_file(c.speakers_file)
if speaker_manager.num_speakers > 0:
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
speaker_manager.num_speakers, ", ".join(speaker_manager.ids)
)
)
@ -461,9 +207,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
out_file_path = os.path.join(out_path, "speakers.json")
print(f" > Saving `speakers.json` to {out_file_path}.")
if c.use_d_vector_file and c.d_vector_file:
speaker_manager.save_d_vectors_to_file(out_file_path)
speaker_manager.save_embeddings_to_file(out_file_path)
else:
speaker_manager.save_speaker_ids_to_file(out_file_path)
speaker_manager.save_ids_to_file(out_file_path)
return speaker_manager
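Since SpeakerManager is now a thin EmbeddingManager subclass, the old speaker-specific names map directly onto the base ones. A sketch with hypothetical files:

from TTS.tts.utils.speakers import SpeakerManager

manager = SpeakerManager(d_vectors_file_path="speakers.json")  # hypothetical d-vector file
print(manager.ids)              # was manager.speaker_ids
print(len(manager.embeddings))  # was manager.d_vectors
print(manager.embedding_dim)    # was manager.d_vector_dim
manager.load_ids_from_file("speaker_ids.json")  # was set_speaker_ids_from_file; hypothetical file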

View File

@ -122,7 +122,7 @@ class Synthesizer(object):
self.tts_model.cuda()
if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
self.tts_model.speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)
def _set_speaker_encoder_paths_from_tts_config(self):
"""Set the encoder paths from the tts model config for models with speaker encoders."""
@ -212,17 +212,17 @@ class Synthesizer(object):
# handle multi-speaker
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the average speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
speaker_name, num_samples=None, randomize=False
)
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
elif not speaker_name and not speaker_wav:
raise ValueError(
@ -244,7 +244,7 @@ class Synthesizer(object):
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
):
if language_name and isinstance(language_name, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
language_id = self.tts_model.language_manager.ids[language_name]
elif not language_name:
raise ValueError(
@ -260,7 +260,7 @@ class Synthesizer(object):
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
use_gl = self.vocoder_model is None
@ -319,7 +319,7 @@ class Synthesizer(object):
if reference_speaker_name and isinstance(reference_speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(
reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name(
reference_speaker_name
)[0]
reference_speaker_embedding = np.array(reference_speaker_embedding)[
@ -327,9 +327,9 @@ class Synthesizer(object):
] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
reference_speaker_id = self.tts_model.speaker_manager.speaker_ids[reference_speaker_name]
reference_speaker_id = self.tts_model.speaker_manager.ids[reference_speaker_name]
else:
reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(
reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
reference_wav
)
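The synthesizer-side lookups after the rename, as a sketch (names and paths hypothetical):

import numpy as np
from TTS.tts.utils.speakers import SpeakerManager

manager = SpeakerManager(d_vectors_file_path="speakers.json")  # hypothetical d-vector file
name = manager.speaker_names[0]
mean_emb = manager.get_mean_embedding(name, num_samples=None, randomize=False)  # was get_mean_d_vector
mean_emb = np.array(mean_emb)[None, :]  # [1 x embedding_dim], as Synthesizer.tts does
first_emb = manager.get_embeddings_by_name(name)[0]  # was get_d_vectors_by_speaker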

View File

@ -119,7 +119,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config)

View File

@ -81,7 +81,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -82,7 +82,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron(config, ap, tokenizer, speaker_manager)

View File

@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron2(config, ap, tokenizer, speaker_manager)

View File

@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron2(config, ap, tokenizer, speaker_manager)

View File

@ -89,7 +89,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -6,7 +6,7 @@ import torch
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
config.audio.resample = True
# create a dummy speaker encoder
model = setup_speaker_encoder_model(config)
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
# load audio processor and speaker encoder
@ -38,19 +38,19 @@ class SpeakerManagerTest(unittest.TestCase):
# load a sample audio and compute embedding
waveform = ap.load_wav(sample_wav_path)
mel = ap.melspectrogram(waveform)
d_vector = manager.compute_d_vector(mel)
d_vector = manager.compute_embeddings(mel)
assert d_vector.shape[1] == 256
# compute d_vector directly from an input file
d_vector = manager.compute_d_vector_from_clip(sample_wav_path)
d_vector2 = manager.compute_d_vector_from_clip(sample_wav_path)
d_vector = manager.compute_embedding_from_clip(sample_wav_path)
d_vector2 = manager.compute_embedding_from_clip(sample_wav_path)
d_vector = torch.FloatTensor(d_vector)
d_vector2 = torch.FloatTensor(d_vector2)
assert d_vector.shape[0] == 256
assert (d_vector - d_vector2).sum() == 0.0
# compute d_vector from a list of wav files.
d_vector3 = manager.compute_d_vector_from_clip([sample_wav_path, sample_wav_path2])
d_vector3 = manager.compute_embedding_from_clip([sample_wav_path, sample_wav_path2])
d_vector3 = torch.FloatTensor(d_vector3)
assert d_vector3.shape[0] == 256
assert (d_vector - d_vector3).sum() != 0.0
@ -62,14 +62,14 @@ class SpeakerManagerTest(unittest.TestCase):
def test_speakers_file_processing():
manager = SpeakerManager(d_vectors_file_path=d_vectors_file_path)
print(manager.num_speakers)
print(manager.d_vector_dim)
print(manager.embedding_dim)
print(manager.clip_ids)
d_vector = manager.get_d_vector_by_clip(manager.clip_ids[0])
d_vector = manager.get_embedding_by_clip(manager.clip_ids[0])
assert len(d_vector) == 256
d_vectors = manager.get_d_vectors_by_speaker(manager.speaker_names[0])
d_vectors = manager.get_embeddings_by_name(manager.speaker_names[0])
assert len(d_vectors[0]) == 256
d_vector1 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=True)
d_vector1 = manager.get_mean_embedding(manager.speaker_names[0], num_samples=2, randomize=True)
assert len(d_vector1) == 256
d_vector2 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=False)
d_vector2 = manager.get_mean_embedding(manager.speaker_names[0], num_samples=2, randomize=False)
assert len(d_vector2) == 256
assert np.sum(np.array(d_vector1) - np.array(d_vector2)) != 0

View File

@ -86,7 +86,7 @@ class TestGlowTTS(unittest.TestCase):
model = GlowTTS(config)
model.speaker_manager = speaker_manager
model.init_multispeaker(config)
self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim)
self.assertEqual(model.c_in_channels, speaker_manager.embedding_dim)
self.assertEqual(model.num_speakers, speaker_manager.num_speakers)
def test_unlock_act_norm_layers(self):

View File

@ -7,7 +7,7 @@ from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.utils.speakers import SpeakerManager
@ -242,9 +242,9 @@ class TestVits(unittest.TestCase):
speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
speaker_encoder_config.model_params["use_torch_spec"] = True
speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
speaker_encoder = setup_encoder_model(speaker_encoder_config).to(device)
speaker_manager = SpeakerManager()
speaker_manager.speaker_encoder = speaker_encoder
speaker_manager.encoder = speaker_encoder
args = VitsArgs(
language_ids_file=LANG_FILE,

View File

@ -38,7 +38,7 @@ def test_run_all_models():
language_manager = LanguageManager(language_ids_file_path=language_files[0])
language_id = language_manager.language_names[0]
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
speaker_id = list(speaker_manager.ids.keys())[0]
run_cli(
f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '