mirror of https://github.com/coqui-ai/TTS.git
Add EmbeddingManager and BaseIDManager (#1374)
parent 1b22f03e98
commit 060e0f9368
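This commit replaces the speaker-specific manager API with the generic BaseIDManager/EmbeddingManager naming (speaker_ids -> ids, d_vectors -> embeddings, speaker_encoder -> encoder, get_mean_d_vector -> get_mean_embedding, compute_d_vector_from_clip -> compute_embedding_from_clip, setup_speaker_encoder_model -> setup_encoder_model, and so on). As a rough sketch of how the renamed SpeakerManager API is used after this change — the sample lists, file paths and speaker names below are illustrative assumptions, not taken from the commit:

from TTS.tts.utils.speakers import SpeakerManager

# map speaker names to integer IDs from data samples
# (train_samples/eval_samples as returned by load_tts_samples() in the recipes)
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
print(speaker_manager.num_speakers, speaker_manager.ids)
speaker_manager.save_ids_to_file("speakers.json")  # formerly save_speaker_ids_to_file()

# work with precomputed embeddings (formerly d_vectors) from a speakers.json file
speaker_manager.load_embeddings_from_file("speakers.json")
mean_emb = speaker_manager.get_mean_embedding("speaker_a", num_samples=None, randomize=False)

# compute an embedding directly from audio with an external encoder
speaker_manager.init_encoder("encoder_model.pth", "encoder_config.json")
emb = speaker_manager.compute_embedding_from_clip("sample.wav")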
@@ -49,7 +49,7 @@ encoder_manager = SpeakerManager(
use_cuda=args.use_cuda,
)
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
class_name_key = encoder_manager.encoder_config.class_name_key
# compute speaker embeddings
speaker_mapping = {}
@@ -63,10 +63,10 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
wav_file_name = os.path.basename(wav_file)
if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
# get the embedding from the old file
embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
embedd = encoder_manager.get_embedding_by_clip(wav_file_name)
else:
# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
embedd = encoder_manager.compute_embedding_from_clip(wav_file)
# create speaker_mapping if target dataset is defined
speaker_mapping[wav_file_name] = {}
@@ -11,8 +11,8 @@ from TTS.tts.utils.speakers import SpeakerManager
def compute_encoder_accuracy(dataset_items, encoder_manager):
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
class_name_key = encoder_manager.encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
class_acc_dict = {}
@@ -22,13 +22,13 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
wav_file = item["audio_file"]
# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
embedd = encoder_manager.compute_embedding_from_clip(wav_file)
if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
embedding = torch.FloatTensor(embedd).unsqueeze(0)
if encoder_manager.use_cuda:
embedding = embedding.cuda()
class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
predicted_label = map_classid_to_classname[str(class_id)]
else:
predicted_label = None
@@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
precompute_num_workers=0,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
speaker_id_mapping=speaker_manager.ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
@@ -278,7 +278,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(
" > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
print(synthesizer.tts_model.speaker_manager.speaker_ids)
print(synthesizer.tts_model.speaker_manager.ids)
return
# query langauge ids of a multi-lingual model.
@@ -286,7 +286,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.language_id_mapping)
print(synthesizer.tts_model.language_manager.ids)
return
# check the arguments against a multi-speaker model.
@@ -12,7 +12,7 @@ from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
@@ -258,7 +258,7 @@ def main(args): # pylint: disable=redefined-outer-name
global train_classes
ap = AudioProcessor(**c.audio)
model = setup_speaker_encoder_model(c)
model = setup_encoder_model(c)
optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
@@ -125,7 +125,7 @@ def to_camel(text):
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_speaker_encoder_model(config: "Coqpit"):
def setup_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
config.model_params["input_dim"],
@@ -143,7 +143,7 @@ def index():
"index.html",
show_details=args.show_details,
use_multi_speaker=use_multi_speaker,
speaker_ids=speaker_manager.speaker_ids if speaker_manager is not None else None,
speaker_ids=speaker_manager.ids if speaker_manager is not None else None,
use_gst=use_gst,
)
@@ -136,18 +136,18 @@ class BaseTTS(BaseTrainerModel):
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.speaker_manager.ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]
return {
"text": text,
@@ -279,23 +279,19 @@ class BaseTTS(BaseTrainerModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None
# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
else:
language_id_mapping = None
@@ -352,13 +348,13 @@ class BaseTTS(BaseTrainerModel):
d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
d_vector = (random.sample(sorted(d_vector), 1),)
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
else random.sample(sorted(self.speaker_manager.ids.values()), 1),
"d_vector": d_vector,
"style_wav": None, # TODO: handle GST style input
}
@@ -405,7 +401,7 @@ class BaseTTS(BaseTrainerModel):
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
self.speaker_manager.save_ids_to_file(output_path)
trainer.config.speakers_file = output_path
# some models don't have `model_args` set
if hasattr(trainer.config, "model_args"):
@@ -416,7 +412,7 @@ class BaseTTS(BaseTrainerModel):
if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
self.language_manager.save_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
@@ -124,7 +124,7 @@ class GlowTTS(BaseTTS):
)
if self.speaker_manager is not None:
assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
config.d_vector_dim == self.speaker_manager.embedding_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
@@ -652,28 +652,28 @@ class Vits(BaseTTS):
# TODO: make this a function
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.speaker_encoder is None and (
if self.speaker_manager.encoder is None and (
not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
)
self.speaker_manager.speaker_encoder.eval()
self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")
if (
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
hasattr(self.speaker_manager.encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
):
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
def _init_speaker_embedding(self):
@@ -887,7 +887,7 @@ class Vits(BaseTTS):
pad_short=True,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
# concate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)
@@ -896,7 +896,7 @@ class Vits(BaseTTS):
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
@@ -1223,18 +1223,18 @@ class Vits(BaseTTS):
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.speaker_manager.ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]
return {
"text": text,
@@ -1289,26 +1289,22 @@ class Vits(BaseTTS):
d_vectors = None
# get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]
if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
batch["speaker_ids"] = speaker_ids
# get d_vectors from audio file names
if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.d_vectors
if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.embeddings
d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
d_vectors = torch.FloatTensor(d_vectors)
# get language ids from language names
if (
self.language_manager is not None
and self.language_manager.language_id_mapping
and self.args.use_language_embedding
):
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
@@ -1490,7 +1486,7 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config)
if config.model_args.speaker_encoder_model_path:
speaker_manager.init_speaker_encoder(
speaker_manager.init_encoder(
config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
@@ -1,6 +1,5 @@
import json
import os
from typing import Dict, List
from typing import Any, Dict, List
import fsspec
import numpy as np
@@ -8,9 +7,10 @@ import torch
from coqpit import Coqpit
from TTS.config import check_config_and_model_args
from TTS.tts.utils.managers import BaseIDManager
class LanguageManager:
class LanguageManager(BaseIDManager):
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.
@@ -25,37 +25,23 @@ class LanguageManager:
>>> language_id_mapper = manager.language_ids
"""
language_id_mapping: Dict = {}
def __init__(
self,
language_ids_file_path: str = "",
config: Coqpit = None,
):
self.language_id_mapping = {}
if language_ids_file_path:
self.set_language_ids_from_file(language_ids_file_path)
super().__init__(id_file_path=language_ids_file_path)
if config:
self.set_language_ids_from_config(config)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
def num_languages(self) -> int:
return len(list(self.language_id_mapping.keys()))
return len(list(self.ids.keys()))
@property
def language_names(self) -> List:
return list(self.language_id_mapping.keys())
return list(self.ids.keys())
@staticmethod
def parse_language_ids_from_config(c: Coqpit) -> Dict:
@@ -79,25 +65,24 @@ class LanguageManager:
"""Set language IDs from config samples.
Args:
items (List): Data sampled returned by `load_meta_data()`.
c (Coqpit): Config.
"""
self.language_id_mapping = self.parse_language_ids_from_config(c)
self.ids = self.parse_language_ids_from_config(c)
def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file.
@staticmethod
def parse_ids_from_data(items: List, parse_key: str) -> Any:
raise NotImplementedError
Args:
file_path (str): Path to the target json file.
"""
self.language_id_mapping = self._load_json(file_path)
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
raise NotImplementedError
def save_language_ids_to_file(self, file_path: str) -> None:
def save_ids_to_file(self, file_path: str) -> None:
"""Save language IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.language_id_mapping)
self._save_json(file_path, self.ids)
@staticmethod
def init_from_config(config: Coqpit) -> "LanguageManager":
@@ -0,0 +1,285 @@
import json
import random
from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np
import torch
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor
class BaseIDManager:
"""Base `ID` Manager class. Every new `ID` manager must inherit this.
It defines common `ID` manager specific functions.
"""
def __init__(self, id_file_path: str = ""):
self.ids = {}
if id_file_path:
self.load_ids_from_file(id_file_path)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
def set_ids_from_data(self, items: List, parse_key: str) -> None:
"""Set IDs from data samples.
Args:
items (List): Data sampled returned by `load_tts_samples()`.
"""
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
def load_ids_from_file(self, file_path: str) -> None:
"""Set IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.ids = self._load_json(file_path)
def save_ids_to_file(self, file_path: str) -> None:
"""Save IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.ids)
def get_random_id(self) -> Any:
"""Get a random embedding.
Args:
Returns:
np.ndarray: embedding.
"""
if self.ids:
return self.ids[random.choices(list(self.ids.keys()))[0]]
return None
@staticmethod
def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]:
"""Parse IDs from data samples retured by `load_tts_samples()`.
Args:
items (list): Data sampled returned by `load_tts_samples()`.
parse_key (str): The key to being used to parse the data.
Returns:
Tuple[Dict]: speaker IDs.
"""
classes = sorted({item[parse_key] for item in items})
ids = {name: i for i, name in enumerate(classes)}
return ids
class EmbeddingManager(BaseIDManager):
"""Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
It defines common `Embedding` manager specific functions.
"""
def __init__(
self,
embedding_file_path: str = "",
id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
use_cuda: bool = False,
):
super().__init__(id_file_path=id_file_path)
self.embeddings = {}
self.embeddings_by_names = {}
self.clip_ids = []
self.encoder = None
self.encoder_ap = None
self.use_cuda = use_cuda
if embedding_file_path:
self.load_embeddings_from_file(embedding_file_path)
if encoder_model_path and encoder_config_path:
self.init_encoder(encoder_model_path, encoder_config_path)
@property
def embedding_dim(self):
"""Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
if self.embeddings:
return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"])
return 0
def save_embeddings_to_file(self, file_path: str) -> None:
"""Save embeddings to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.embeddings)
def load_embeddings_from_file(self, file_path: str) -> None:
"""Load embeddings from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.embeddings = self._load_json(file_path)
speakers = sorted({x["name"] for x in self.embeddings.values()})
self.ids = {name: i for i, name in enumerate(speakers)}
self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys())))
# cache embeddings_by_names for fast inference using a bigger speakers.json
self.embeddings_by_names = self.get_embeddings_by_names()
def get_embedding_by_clip(self, clip_idx: str) -> List:
"""Get embedding by clip ID.
Args:
clip_idx (str): Target clip ID.
Returns:
List: embedding as a list.
"""
return self.embeddings[clip_idx]["embedding"]
def get_embeddings_by_name(self, idx: str) -> List[List]:
"""Get all embeddings of a speaker.
Args:
idx (str): Target name.
Returns:
List[List]: all the embeddings of the given speaker.
"""
return self.embeddings_by_names[idx]
def get_embeddings_by_names(self) -> Dict:
"""Get all embeddings by names.
Returns:
Dict: all the embeddings of each speaker.
"""
embeddings_by_names = {}
for x in self.embeddings.values():
if x["name"] not in embeddings_by_names.keys():
embeddings_by_names[x["name"]] = [x["embedding"]]
else:
embeddings_by_names[x["name"]].append(x["embedding"])
return embeddings_by_names
def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
"""Get mean embedding of a idx.
Args:
idx (str): Target name.
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.
Returns:
np.ndarray: Mean embedding.
"""
embeddings = self.get_embeddings_by_name(idx)
if num_samples is None:
embeddings = np.stack(embeddings).mean(0)
else:
assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}"
if randomize:
embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0)
else:
embeddings = np.stack(embeddings[:num_samples]).mean(0)
return embeddings
def get_random_embedding(self) -> Any:
"""Get a random embedding.
Args:
Returns:
np.ndarray: embedding.
"""
if self.embeddings:
return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"]
return None
def get_clips(self) -> List:
return sorted(self.embeddings.keys())
def init_encoder(self, model_path: str, config_path: str) -> None:
"""Initialize a speaker encoder model.
Args:
model_path (str): Model file path.
config_path (str): Model config file path.
"""
self.encoder_config = load_config(config_path)
self.encoder = setup_encoder_model(self.encoder_config)
self.encoder_criterion = self.encoder.load_checkpoint(
self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
)
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
"""Compute a embedding from a given audio file.
Args:
wav_file (Union[str, List[str]]): Target file path.
Returns:
list: Computed embedding.
"""
def _compute(wav_file: str):
waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
if not self.encoder_config.model_params.get("use_torch_spec", False):
m_input = self.encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda:
m_input = m_input.cuda()
m_input = m_input.unsqueeze(0)
embedding = self.encoder.compute_embedding(m_input)
return embedding
if isinstance(wav_file, list):
# compute the mean embedding
embeddings = None
for wf in wav_file:
embedding = _compute(wf)
if embeddings is None:
embeddings = embedding
else:
embeddings += embedding
return (embeddings / len(wav_file))[0].tolist()
embedding = _compute(wav_file)
return embedding[0].tolist()
def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
"""Compute embedding from features.
Args:
feats (Union[torch.Tensor, np.ndarray]): Input features.
Returns:
List: computed embedding.
"""
if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats)
if feats.ndim == 2:
feats = feats.unsqueeze(0)
if self.use_cuda:
feats = feats.cuda()
return self.encoder.compute_embedding(feats)
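A minimal usage sketch of the two classes added above. The JSON and wav paths and the speaker name are illustrative assumptions; the embedding file is assumed to hold per-clip entries with "name" and "embedding" keys, as load_embeddings_from_file() above expects:

from TTS.tts.utils.managers import BaseIDManager, EmbeddingManager

# BaseIDManager: plain name-to-ID bookkeeping
id_manager = BaseIDManager()
id_manager.set_ids_from_data(samples, parse_key="speaker_name")  # samples from load_tts_samples()
id_manager.save_ids_to_file("ids.json")

# EmbeddingManager: IDs plus per-clip embeddings and an optional encoder
emb_manager = EmbeddingManager(embedding_file_path="speakers.json")
print(emb_manager.embedding_dim, len(emb_manager.clip_ids))
emb = emb_manager.get_embedding_by_clip(emb_manager.clip_ids[0])
mean_emb = emb_manager.get_mean_embedding("speaker_a", num_samples=2, randomize=True)

# with an external encoder loaded, embeddings can be computed directly from audio
emb_manager.init_encoder("encoder_model.pth", "encoder_config.json")
emb_from_wav = emb_manager.compute_embedding_from_clip("sample.wav")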
@@ -1,19 +1,17 @@
import json
import os
import random
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
class SpeakerManager:
class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by speaker or clip.
@@ -50,7 +48,7 @@ class SpeakerManager:
>>> # load a sample audio and compute embedding
>>> waveform = ap.load_wav(sample_wav_path)
>>> mel = ap.melspectrogram(waveform)
>>> d_vector = manager.compute_d_vector(mel.T)
>>> d_vector = manager.compute_embeddings(mel.T)
"""
def __init__(
@@ -62,279 +60,27 @@ class SpeakerManager:
encoder_config_path: str = "",
use_cuda: bool = False,
):
self.d_vectors = {}
self.speaker_ids = {}
self.d_vectors_by_speakers = {}
self.clip_ids = []
self.speaker_encoder = None
self.speaker_encoder_ap = None
self.use_cuda = use_cuda
super().__init__(
embedding_file_path=d_vectors_file_path,
id_file_path=speaker_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
use_cuda=use_cuda,
)
if data_items:
self.speaker_ids, _ = self.parse_speakers_from_data(data_items)
if d_vectors_file_path:
self.set_d_vectors_from_file(d_vectors_file_path)
if speaker_id_file_path:
self.set_speaker_ids_from_file(speaker_id_file_path)
if encoder_model_path and encoder_config_path:
self.init_speaker_encoder(encoder_model_path, encoder_config_path)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
self.set_ids_from_data(data_items, parse_key="speaker_name")
@property
def num_speakers(self):
return len(self.speaker_ids)
return len(self.ids)
@property
def speaker_names(self):
return list(self.speaker_ids.keys())
@property
def d_vector_dim(self):
"""Dimensionality of d_vectors. If d_vectors are not loaded, returns zero."""
if self.d_vectors:
return len(self.d_vectors[list(self.d_vectors.keys())[0]]["embedding"])
return 0
@staticmethod
def parse_speakers_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples retured by `load_tts_samples()`.
Args:
items (list): Data sampled returned by `load_tts_samples()`.
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item["speaker_name"] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def set_speaker_ids_from_data(self, items: List) -> None:
"""Set speaker IDs from data samples.
Args:
items (List): Data sampled returned by `load_tts_samples()`.
"""
self.speaker_ids, _ = self.parse_speakers_from_data(items)
def set_speaker_ids_from_file(self, file_path: str) -> None:
"""Set speaker IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.speaker_ids = self._load_json(file_path)
def save_speaker_ids_to_file(self, file_path: str) -> None:
"""Save speaker IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.speaker_ids)
def save_d_vectors_to_file(self, file_path: str) -> None:
"""Save d_vectors to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.d_vectors)
def set_d_vectors_from_file(self, file_path: str) -> None:
"""Load d_vectors from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.d_vectors = self._load_json(file_path)
speakers = sorted({x["name"] for x in self.d_vectors.values()})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
# cache d_vectors_by_speakers for fast inference using a bigger speakers.json
self.d_vectors_by_speakers = self.get_d_vectors_by_speakers()
def get_d_vector_by_clip(self, clip_idx: str) -> List:
"""Get d_vector by clip ID.
Args:
clip_idx (str): Target clip ID.
Returns:
List: d_vector as a list.
"""
return self.d_vectors[clip_idx]["embedding"]
def get_d_vectors_by_speaker(self, speaker_idx: str) -> List[List]:
"""Get all d_vectors of a speaker.
Args:
speaker_idx (str): Target speaker ID.
Returns:
List[List]: all the d_vectors of the given speaker.
"""
return self.d_vectors_by_speakers[speaker_idx]
def get_d_vectors_by_speakers(self) -> Dict:
"""Get all d_vectors by speaker.
Returns:
Dict: all the d_vectors of each speaker.
"""
d_vectors_by_speakers = {}
for x in self.d_vectors.values():
if x["name"] not in d_vectors_by_speakers.keys():
d_vectors_by_speakers[x["name"]] = [x["embedding"]]
else:
d_vectors_by_speakers[x["name"]].append(x["embedding"])
return d_vectors_by_speakers
def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
"""Get mean d_vector of a speaker ID.
Args:
speaker_idx (str): Target speaker ID.
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
randomize (bool, optional): Pick random `num_samples` of d_vectors. Defaults to False.
Returns:
np.ndarray: Mean d_vector.
"""
d_vectors = self.get_d_vectors_by_speaker(speaker_idx)
if num_samples is None:
d_vectors = np.stack(d_vectors).mean(0)
else:
assert len(d_vectors) >= num_samples, f" [!] speaker {speaker_idx} has number of samples < {num_samples}"
if randomize:
d_vectors = np.stack(random.choices(d_vectors, k=num_samples)).mean(0)
else:
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
return d_vectors
def get_random_speaker_id(self) -> Any:
"""Get a random d_vector.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.speaker_ids:
return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
return None
def get_random_d_vector(self) -> Any:
"""Get a random D ID.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.d_vectors:
return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
return None
return list(self.ids.keys())
def get_speakers(self) -> List:
return self.speaker_ids
def get_clips(self) -> List:
return sorted(self.d_vectors.keys())
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
"""Initialize a speaker encoder model.
Args:
model_path (str): Model file path.
config_path (str): Model config file path.
"""
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(
self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda
)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
"""Compute a d_vector from a given audio file.
Args:
wav_file (Union[str, List[str]]): Target file path.
Returns:
list: Computed d_vector.
"""
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda:
m_input = m_input.cuda()
m_input = m_input.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(m_input)
return d_vector
if isinstance(wav_file, list):
# compute the mean d_vector
d_vectors = None
for wf in wav_file:
d_vector = _compute(wf)
if d_vectors is None:
d_vectors = d_vector
else:
d_vectors += d_vector
return (d_vectors / len(wav_file))[0].tolist()
d_vector = _compute(wav_file)
return d_vector[0].tolist()
def compute_d_vector(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
"""Compute d_vector from features.
Args:
feats (Union[torch.Tensor, np.ndarray]): Input features.
Returns:
List: computed d_vector.
"""
if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats)
if feats.ndim == 2:
feats = feats.unsqueeze(0)
if self.use_cuda:
feats = feats.cuda()
return self.speaker_encoder.compute_embedding(feats)
def run_umap(self):
# TODO: implement speaker encoder
raise NotImplementedError
def plot_embeddings(self):
# TODO: implement speaker encoder
raise NotImplementedError
return self.ids
@staticmethod
def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
@@ -420,7 +166,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager = SpeakerManager()
if c.use_speaker_embedding:
if data is not None:
speaker_manager.set_speaker_ids_from_data(data)
speaker_manager.set_ids_from_data(data, parse_key="speaker_name")
if restore_path:
speakers_file = _set_file_path(restore_path)
# restoring speaker manager from a previous run.
@@ -432,27 +178,27 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
)
speaker_manager.load_d_vectors_file(c.d_vector_file)
speaker_manager.set_d_vectors_from_file(speakers_file)
speaker_manager.load_embeddings_from_file(c.d_vector_file)
speaker_manager.load_embeddings_from_file(speakers_file)
elif not c.use_d_vector_file: # restor speaker manager with speaker ID file.
speaker_ids_from_data = speaker_manager.speaker_ids
speaker_manager.set_speaker_ids_from_file(speakers_file)
speaker_ids_from_data = speaker_manager.ids
speaker_manager.load_ids_from_file(speakers_file)
assert all(
speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data
speaker in speaker_manager.ids for speaker in speaker_ids_from_data
), " [!] You cannot introduce new speakers to a pre-trained model."
elif c.use_d_vector_file and c.d_vector_file:
# new speaker manager with external speaker embeddings.
speaker_manager.set_d_vectors_from_file(c.d_vector_file)
speaker_manager.load_embeddings_from_file(c.d_vector_file)
elif c.use_d_vector_file and not c.d_vector_file:
raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
speaker_manager.load_ids_from_file(c.speakers_file)
if speaker_manager.num_speakers > 0:
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
speaker_manager.num_speakers, ", ".join(speaker_manager.ids)
)
)
@@ -461,9 +207,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
out_file_path = os.path.join(out_path, "speakers.json")
print(f" > Saving `speakers.json` to {out_file_path}.")
if c.use_d_vector_file and c.d_vector_file:
speaker_manager.save_d_vectors_to_file(out_file_path)
speaker_manager.save_embeddings_to_file(out_file_path)
else:
speaker_manager.save_speaker_ids_to_file(out_file_path)
speaker_manager.save_ids_to_file(out_file_path)
return speaker_manager
@@ -122,7 +122,7 @@ class Synthesizer(object):
self.tts_model.cuda()
if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
self.tts_model.speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)
def _set_speaker_encoder_paths_from_tts_config(self):
"""Set the encoder paths from the tts model config for models with speaker encoders."""
@@ -212,17 +212,17 @@ class Synthesizer(object):
# handle multi-speaker
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the average speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_mean_d_vector(
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
speaker_name, num_samples=None, randomize=False
)
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
elif not speaker_name and not speaker_wav:
raise ValueError(
@@ -244,7 +244,7 @@ class Synthesizer(object):
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
):
if language_name and isinstance(language_name, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
language_id = self.tts_model.language_manager.ids[language_name]
elif not language_name:
raise ValueError(
@@ -260,7 +260,7 @@ class Synthesizer(object):
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
use_gl = self.vocoder_model is None
@@ -319,7 +319,7 @@ class Synthesizer(object):
if reference_speaker_name and isinstance(reference_speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
reference_speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(
reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name(
reference_speaker_name
)[0]
reference_speaker_embedding = np.array(reference_speaker_embedding)[
@@ -327,9 +327,9 @@ class Synthesizer(object):
] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
reference_speaker_id = self.tts_model.speaker_manager.speaker_ids[reference_speaker_name]
reference_speaker_id = self.tts_model.speaker_manager.ids[reference_speaker_name]
else:
reference_speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(
reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
reference_wav
)
@@ -119,7 +119,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config)
@@ -81,7 +81,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model
@@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model
@@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.num_speakers = speaker_manager.num_speakers
# init model
@@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model
@@ -82,7 +82,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron(config, ap, tokenizer, speaker_manager)
@@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron2(config, ap, tokenizer, speaker_manager)
@@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron2(config, ap, tokenizer, speaker_manager)
@@ -89,7 +89,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model
@@ -6,7 +6,7 @@ import torch
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
config.audio.resample = True
# create a dummy speaker encoder
model = setup_speaker_encoder_model(config)
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
# load audio processor and speaker encoder
@@ -38,19 +38,19 @@ class SpeakerManagerTest(unittest.TestCase):
# load a sample audio and compute embedding
waveform = ap.load_wav(sample_wav_path)
mel = ap.melspectrogram(waveform)
d_vector = manager.compute_d_vector(mel)
d_vector = manager.compute_embeddings(mel)
assert d_vector.shape[1] == 256
# compute d_vector directly from an input file
d_vector = manager.compute_d_vector_from_clip(sample_wav_path)
d_vector2 = manager.compute_d_vector_from_clip(sample_wav_path)
d_vector = manager.compute_embedding_from_clip(sample_wav_path)
d_vector2 = manager.compute_embedding_from_clip(sample_wav_path)
d_vector = torch.FloatTensor(d_vector)
d_vector2 = torch.FloatTensor(d_vector2)
assert d_vector.shape[0] == 256
assert (d_vector - d_vector2).sum() == 0.0
# compute d_vector from a list of wav files.
d_vector3 = manager.compute_d_vector_from_clip([sample_wav_path, sample_wav_path2])
d_vector3 = manager.compute_embedding_from_clip([sample_wav_path, sample_wav_path2])
d_vector3 = torch.FloatTensor(d_vector3)
assert d_vector3.shape[0] == 256
assert (d_vector - d_vector3).sum() != 0.0
@@ -62,14 +62,14 @@ class SpeakerManagerTest(unittest.TestCase):
def test_speakers_file_processing():
manager = SpeakerManager(d_vectors_file_path=d_vectors_file_path)
print(manager.num_speakers)
print(manager.d_vector_dim)
print(manager.embedding_dim)
print(manager.clip_ids)
d_vector = manager.get_d_vector_by_clip(manager.clip_ids[0])
d_vector = manager.get_embedding_by_clip(manager.clip_ids[0])
assert len(d_vector) == 256
d_vectors = manager.get_d_vectors_by_speaker(manager.speaker_names[0])
d_vectors = manager.get_embeddings_by_name(manager.speaker_names[0])
assert len(d_vectors[0]) == 256
d_vector1 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=True)
d_vector1 = manager.get_mean_embedding(manager.speaker_names[0], num_samples=2, randomize=True)
assert len(d_vector1) == 256
d_vector2 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=False)
d_vector2 = manager.get_mean_embedding(manager.speaker_names[0], num_samples=2, randomize=False)
assert len(d_vector2) == 256
assert np.sum(np.array(d_vector1) - np.array(d_vector2)) != 0
@@ -86,7 +86,7 @@ class TestGlowTTS(unittest.TestCase):
model = GlowTTS(config)
model.speaker_manager = speaker_manager
model.init_multispeaker(config)
self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim)
self.assertEqual(model.c_in_channels, speaker_manager.embedding_dim)
self.assertEqual(model.num_speakers, speaker_manager.num_speakers)
def test_unlock_act_norm_layers(self):
@@ -7,7 +7,7 @@ from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.utils.speakers import SpeakerManager
@@ -242,9 +242,9 @@ class TestVits(unittest.TestCase):
speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
speaker_encoder_config.model_params["use_torch_spec"] = True
speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
speaker_encoder = setup_encoder_model(speaker_encoder_config).to(device)
speaker_manager = SpeakerManager()
speaker_manager.speaker_encoder = speaker_encoder
speaker_manager.encoder = speaker_encoder
args = VitsArgs(
language_ids_file=LANG_FILE,
@@ -38,7 +38,7 @@ def test_run_all_models():
language_manager = LanguageManager(language_ids_file_path=language_files[0])
language_id = language_manager.language_names[0]
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
speaker_id = list(speaker_manager.ids.keys())[0]
run_cli(
f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '