mirror of https://github.com/coqui-ai/TTS.git

d-vector handling (#1945)

* Update BaseDatasetConfig
  - Add dataset_name
  - Change name to formatter_name
* Update compute_embedding
  - Allow entering dataset by args
  - Use released model by default
  - Use the new key format
* Update loading
* Update recipes
* Update other dep code
* Update tests
* Fixup
* Load multiple embedding files
* Fix argument names in dep code
* Update docs
* Fix argument name
* Fix linter

This commit is contained in:
parent 371772c355
commit 9e5a469c64
@@ -6,38 +6,81 @@ import torch
 from tqdm import tqdm

 from TTS.config import load_config
+from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.managers import save_file
 from TTS.tts.utils.speakers import SpeakerManager

 parser = argparse.ArgumentParser(
-    description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
+    description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
     """
     Example runs:
-    python TTS/bin/compute_embeddings.py speaker_encoder_model.pth speaker_encoder_config.json dataset_config.json
+    python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json
+
+    python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name vctk --dataset_path /path/to/vctk/dataset --dataset_name my_vctk
     """,
     formatter_class=RawTextHelpFormatter,
 )
-parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
-parser.add_argument("config_path", type=str, help="Path to model config file.")
-parser.add_argument("config_dataset_path", type=str, help="Path to dataset config file.")
+parser.add_argument(
+    "--model_path",
+    type=str,
+    help="Path to model checkpoint file. It defaults to the released speaker encoder.",
+    default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar",
+)
+parser.add_argument(
+    "--config_path",
+    type=str,
+    help="Path to model config file. It defaults to the released speaker encoder config.",
+    default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json",
+)
+parser.add_argument(
+    "--config_dataset_path",
+    type=str,
+    help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.",
+    default=None,
+)
 parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.", default="speakers.pth")
 parser.add_argument("--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None)
 parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False)
 parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False)
+parser.add_argument(
+    "--formatter_name",
+    type=str,
+    help="Name of the formatter to use. You either need to provide this or `config_dataset_path`",
+    default=None,
+)
+parser.add_argument(
+    "--dataset_name",
+    type=str,
+    help="Name of the dataset to use. You either need to provide this or `config_dataset_path`",
+    default=None,
+)
+parser.add_argument(
+    "--dataset_path",
+    type=str,
+    help="Path to the dataset. You either need to provide this or `config_dataset_path`",
+    default=None,
+)

 args = parser.parse_args()

 use_cuda = torch.cuda.is_available() and not args.disable_cuda

-c_dataset = load_config(args.config_dataset_path)
-
-meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not args.no_eval)
+if args.config_dataset_path is not None:
+    c_dataset = load_config(args.config_dataset_path)
+    meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not args.no_eval)
+else:
+    c_dataset = BaseDatasetConfig()
+    c_dataset.formatter = args.formatter_name
+    c_dataset.dataset_name = args.dataset_name
+    c_dataset.path = args.dataset_path
+    meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not args.no_eval)

 if meta_data_eval is None:
-    wav_files = meta_data_train
+    samples = meta_data_train
 else:
-    wav_files = meta_data_train + meta_data_eval
+    samples = meta_data_train + meta_data_eval

 encoder_manager = SpeakerManager(
     encoder_model_path=args.model_path,
@@ -50,25 +93,25 @@ class_name_key = encoder_manager.encoder_config.class_name_key

 # compute speaker embeddings
 speaker_mapping = {}
-for idx, wav_file in enumerate(tqdm(wav_files)):
-    if isinstance(wav_file, dict):
-        class_name = wav_file[class_name_key]
-        wav_file = wav_file["audio_file"]
-    else:
-        class_name = None
+for idx, fields in enumerate(tqdm(samples)):
+    class_name = fields[class_name_key]
+    audio_file = fields["audio_file"]
+    dataset_name = fields["dataset_name"]
+    root_path = fields["root_path"]

-    wav_file_name = os.path.basename(wav_file)
-    if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
+    relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0]
+    embedding_key = f"{dataset_name}#{relfilepath}"
+    if args.old_file is not None and embedding_key in encoder_manager.clip_ids:
         # get the embedding from the old file
-        embedd = encoder_manager.get_embedding_by_clip(wav_file_name)
+        embedd = encoder_manager.get_embedding_by_clip(embedding_key)
     else:
         # extract the embedding
-        embedd = encoder_manager.compute_embedding_from_clip(wav_file)
+        embedd = encoder_manager.compute_embedding_from_clip(audio_file)

     # create speaker_mapping if target dataset is defined
-    speaker_mapping[wav_file_name] = {}
-    speaker_mapping[wav_file_name]["name"] = class_name
-    speaker_mapping[wav_file_name]["embedding"] = embedd
+    speaker_mapping[embedding_key] = {}
+    speaker_mapping[embedding_key]["name"] = class_name
+    speaker_mapping[embedding_key]["embedding"] = embedd

 if speaker_mapping:
     # save speaker_mapping if target dataset is defined
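The new embedding key can be sketched in isolation; the values below are toy placeholders standing in for one entry of `samples`, not taken from the commit:

```python
import os

# Toy values; the real fields come from `load_tts_samples`.
root_path = "/data/my_vctk/"
audio_file = "/data/my_vctk/wav48/p225/p225_001.wav"
dataset_name = "my_vctk"

# Same key construction as in the loop above.
relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0]
embedding_key = f"{dataset_name}#{relfilepath}"
print(embedding_key)  # my_vctk#wav48/p225/p225_001
```

Because the key now carries the dataset name and the extension-less relative path instead of a bare file basename, clips from different datasets can no longer collide in one embeddings file.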
@@ -37,7 +37,7 @@ def setup_loader(ap, r, verbose=False):
         precompute_num_workers=0,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_id_mapping=speaker_manager.ids if c.use_speaker_embedding else None,
+        speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
         d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
     )
@@ -323,7 +323,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         print(
             " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
         )
-        print(synthesizer.tts_model.speaker_manager.ids)
+        print(synthesizer.tts_model.speaker_manager.name_to_id)
         return

     # query langauge ids of a multi-lingual model.
@@ -193,21 +193,24 @@ class BaseDatasetConfig(Coqpit):
     """Base config for TTS datasets.

     Args:
-        name (str):
-            Dataset name that defines the preprocessor in use. Defaults to None.
+        formatter (str):
+            Formatter name that defines the formatter used in ```TTS.tts.datasets.formatter```. Defaults to `""`.
+
+        dataset_name (str):
+            Unique name for the dataset. Defaults to `""`.

         path (str):
-            Root path to the dataset files. Defaults to None.
+            Root path to the dataset files. Defaults to `""`.

         meta_file_train (str):
             Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
-            Defaults to None.
+            Defaults to `""`.

         ignored_speakers (List):
             List of speakers IDs that are not used at the training. Default None.

         language (str):
-            Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to None.
+            Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.

         meta_file_val (str):
             Name of the dataset meta file that defines the instances used at validation.
@@ -217,7 +220,8 @@ class BaseDatasetConfig(Coqpit):
             train the duration predictor.
     """

-    name: str = ""
+    formatter: str = ""
+    dataset_name: str = ""
     path: str = ""
     meta_file_train: str = ""
     ignored_speakers: List[str] = None
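With the renamed fields, a dataset config is built as in this minimal sketch; the path and `dataset_name` values are placeholders:

```python
from TTS.config.shared_configs import BaseDatasetConfig

dataset_config = BaseDatasetConfig(
    formatter="ljspeech",            # selects the formatter (was `name`)
    dataset_name="my_ljspeech",      # new: unique name, used in embedding keys
    meta_file_train="metadata.csv",
    path="/path/to/LJSpeech-1.1/",   # placeholder dataset root
)
```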
@@ -230,7 +234,7 @@ class BaseDatasetConfig(Coqpit):
     ):
         """Check config fields"""
         c = asdict(self)
-        check_argument("name", c, restricted=True)
+        check_argument("formatter", c, restricted=True)
         check_argument("path", c, restricted=True)
         check_argument("meta_file_train", c, restricted=True)
         check_argument("meta_file_val", c, restricted=False)
@@ -147,7 +147,7 @@ def index():
         "index.html",
         show_details=args.show_details,
         use_multi_speaker=use_multi_speaker,
-        speaker_ids=speaker_manager.ids if speaker_manager is not None else None,
+        speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None,
         use_gst=use_gst,
     )
@@ -97,7 +97,8 @@ def load_tts_samples(
     if not isinstance(datasets, list):
         datasets = [datasets]
     for dataset in datasets:
-        name = dataset["name"]
+        formatter_name = dataset["formatter"]
+        dataset_name = dataset["dataset_name"]
         root_path = dataset["path"]
         meta_file_train = dataset["meta_file_train"]
         meta_file_val = dataset["meta_file_val"]
@@ -106,17 +107,18 @@ def load_tts_samples(

         # setup the right data processor
         if formatter is None:
-            formatter = _get_formatter_by_name(name)
+            formatter = _get_formatter_by_name(formatter_name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_train]

         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
+                meta_data_eval = [
+                    {**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_eval
+                ]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval
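Given such a config, the loader now tags every sample with its dataset; a sketch reusing the `dataset_config` from above:

```python
from TTS.tts.datasets import load_tts_samples

# Each sample dict now carries "dataset_name" alongside "language".
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
print(train_samples[0]["dataset_name"])  # "my_ljspeech"
```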
@@ -144,11 +144,11 @@ class BaseTTS(BaseTrainerModel):
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
             else:
-                speaker_id = self.speaker_manager.ids[speaker_name]
+                speaker_id = self.speaker_manager.name_to_id[speaker_name]

         # get language id
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
-            language_id = self.language_manager.ids[language_name]
+            language_id = self.language_manager.name_to_id[language_name]

         return {
             "text": text,
@@ -288,11 +288,13 @@ class BaseTTS(BaseTrainerModel):
         # setup multi-speaker attributes
         if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
             if hasattr(config, "model_args"):
-                speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
+                speaker_id_mapping = (
+                    self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
+                )
                 d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                 config.use_d_vector_file = config.model_args.use_d_vector_file
             else:
-                speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
+                speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
                 d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
         else:
             speaker_id_mapping = None
@@ -300,7 +302,7 @@ class BaseTTS(BaseTrainerModel):

         # setup multi-lingual attributes
         if hasattr(self, "language_manager") and self.language_manager is not None:
-            language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
+            language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
         else:
             language_id_mapping = None
@@ -363,7 +365,7 @@ class BaseTTS(BaseTrainerModel):
         aux_inputs = {
             "speaker_id": None
             if not self.config.use_speaker_embedding
-            else random.sample(sorted(self.speaker_manager.ids.values()), 1),
+            else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1),
             "d_vector": d_vector,
             "style_wav": None,  # TODO: handle GST style input
         }
@@ -1185,7 +1185,6 @@ class Vits(BaseTTS):
         y_lengths = torch.tensor([y.size(-1)]).to(y.device)
         speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
         speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
-        # print(y.shape, y_lengths.shape)
         wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
         return wav
@@ -1402,11 +1401,11 @@ class Vits(BaseTTS):
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
             else:
-                speaker_id = self.speaker_manager.ids[speaker_name]
+                speaker_id = self.speaker_manager.name_to_id[speaker_name]

         # get language id
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
-            language_id = self.language_manager.ids[language_name]
+            language_id = self.language_manager.name_to_id[language_name]

         return {
             "text": text,
@@ -1461,8 +1460,8 @@ class Vits(BaseTTS):
             d_vectors = None

         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
-            speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
+        if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding:
+            speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]]

         if speaker_ids is not None:
             speaker_ids = torch.LongTensor(speaker_ids)
@@ -1475,8 +1474,8 @@ class Vits(BaseTTS):
             d_vectors = torch.FloatTensor(d_vectors)

         # get language ids from language names
-        if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
-            language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
+        if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding:
+            language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]]

         if language_ids is not None:
             language_ids = torch.LongTensor(language_ids)
@@ -37,11 +37,11 @@ class LanguageManager(BaseIDManager):

     @property
     def num_languages(self) -> int:
-        return len(list(self.ids.keys()))
+        return len(list(self.name_to_id.keys()))

     @property
     def language_names(self) -> List:
-        return list(self.ids.keys())
+        return list(self.name_to_id.keys())

     @staticmethod
     def parse_language_ids_from_config(c: Coqpit) -> Dict:
@@ -67,7 +67,7 @@ class LanguageManager(BaseIDManager):
         Args:
             c (Coqpit): Config.
         """
-        self.ids = self.parse_language_ids_from_config(c)
+        self.name_to_id = self.parse_language_ids_from_config(c)

     @staticmethod
     def parse_ids_from_data(items: List, parse_key: str) -> Any:
@@ -82,7 +82,7 @@ class LanguageManager(BaseIDManager):
         Args:
             file_path (str): Path to the output file.
         """
-        self._save_json(file_path, self.ids)
+        self._save_json(file_path, self.name_to_id)

     @staticmethod
     def init_from_config(config: Coqpit) -> "LanguageManager":
@@ -39,7 +39,7 @@ class BaseIDManager:
     """

     def __init__(self, id_file_path: str = ""):
-        self.ids = {}
+        self.name_to_id = {}

         if id_file_path:
             self.load_ids_from_file(id_file_path)
@@ -60,7 +60,7 @@ class BaseIDManager:
         Args:
             items (List): Data sampled returned by `load_tts_samples()`.
         """
-        self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
+        self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)

     def load_ids_from_file(self, file_path: str) -> None:
         """Set IDs from a file.
@@ -68,7 +68,7 @@ class BaseIDManager:
         Args:
             file_path (str): Path to the file.
         """
-        self.ids = load_file(file_path)
+        self.name_to_id = load_file(file_path)

     def save_ids_to_file(self, file_path: str) -> None:
         """Save IDs to a json file.
@@ -76,7 +76,7 @@ class BaseIDManager:
         Args:
             file_path (str): Path to the output file.
         """
-        save_file(self.ids, file_path)
+        save_file(self.name_to_id, file_path)

     def get_random_id(self) -> Any:
         """Get a random embedding.
@@ -86,8 +86,8 @@ class BaseIDManager:
         Returns:
             np.ndarray: embedding.
         """
-        if self.ids:
-            return self.ids[random.choices(list(self.ids.keys()))[0]]
+        if self.name_to_id:
+            return self.name_to_id[random.choices(list(self.name_to_id.keys()))[0]]

         return None
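The `ids` to `name_to_id` rename is mechanical throughout the managers; a toy sketch of the renamed attribute on the base class:

```python
from TTS.tts.utils.managers import BaseIDManager

manager = BaseIDManager()
manager.name_to_id = {"speaker_a": 0, "speaker_b": 1}  # toy mapping
print(manager.get_random_id())  # 0 or 1, picked at random
```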
@@ -109,11 +109,27 @@ class BaseIDManager:
 class EmbeddingManager(BaseIDManager):
     """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
     It defines common `Embedding` manager specific functions.
+
+    It expects embeddings files in the following format:
+
+    ::
+
+        {
+            'audio_file_key': {
+                'name': 'category_name',
+                'embedding': [<embedding_values>]
+            },
+            ...
+        }
+
+    `audio_file_key` is a unique key to the audio file in the dataset. It can be the path to the file or any other unique key.
+    `embedding` is the embedding vector of the audio file.
+    `name` can be the name of the speaker of the audio file.
     """

     def __init__(
         self,
-        embedding_file_path: str = "",
+        embedding_file_path: Union[str, List[str]] = "",
         id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
@@ -129,11 +145,24 @@ class EmbeddingManager(BaseIDManager):
         self.use_cuda = use_cuda

         if embedding_file_path:
-            self.load_embeddings_from_file(embedding_file_path)
+            if isinstance(embedding_file_path, list):
+                self.load_embeddings_from_list_of_files(embedding_file_path)
+            else:
+                self.load_embeddings_from_file(embedding_file_path)

         if encoder_model_path and encoder_config_path:
             self.init_encoder(encoder_model_path, encoder_config_path, use_cuda)

+    @property
+    def num_embeddings(self):
+        """Get number of embeddings."""
+        return len(self.embeddings)
+
+    @property
+    def num_names(self):
+        """Get number of embedding names."""
+        return len(self.embeddings_by_names)
+
     @property
     def embedding_dim(self):
         """Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
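A list of embedding files can now be passed straight to the constructor; a sketch with placeholder file names:

```python
from TTS.tts.utils.managers import EmbeddingManager

manager = EmbeddingManager(
    embedding_file_path=["speakers_vctk.pth", "speakers_libritts.pth"]  # placeholders
)
print(manager.num_embeddings, manager.num_names)  # via the new properties
```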
@@ -141,6 +170,11 @@ class EmbeddingManager(BaseIDManager):
             return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"])
         return 0

+    @property
+    def embedding_names(self):
+        """Get embedding names."""
+        return list(self.embeddings_by_names.keys())
+
     def save_embeddings_to_file(self, file_path: str) -> None:
         """Save embeddings to a json file.
@@ -149,20 +183,57 @@ class EmbeddingManager(BaseIDManager):
         """
         save_file(self.embeddings, file_path)

+    @staticmethod
+    def read_embeddings_from_file(file_path: str):
+        """Load embeddings from a json file.
+
+        Args:
+            file_path (str): Path to the file.
+        """
+        embeddings = load_file(file_path)
+        speakers = sorted({x["name"] for x in embeddings.values()})
+        name_to_id = {name: i for i, name in enumerate(speakers)}
+        clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys())))
+        # cache embeddings_by_names for fast inference using a bigger speakers.json
+        embeddings_by_names = {}
+        for x in embeddings.values():
+            if x["name"] not in embeddings_by_names.keys():
+                embeddings_by_names[x["name"]] = [x["embedding"]]
+            else:
+                embeddings_by_names[x["name"]].append(x["embedding"])
+        return name_to_id, clip_ids, embeddings, embeddings_by_names
+
     def load_embeddings_from_file(self, file_path: str) -> None:
         """Load embeddings from a json file.

         Args:
             file_path (str): Path to the target json file.
         """
-        self.embeddings = load_file(file_path)
-
-        speakers = sorted({x["name"] for x in self.embeddings.values()})
-        self.ids = {name: i for i, name in enumerate(speakers)}
-
-        self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys())))
-        # cache embeddings_by_names for fast inference using a bigger speakers.json
-        self.embeddings_by_names = self.get_embeddings_by_names()
+        self.name_to_id, self.clip_ids, self.embeddings, self.embeddings_by_names = self.read_embeddings_from_file(
+            file_path
+        )
+
+    def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None:
+        """Load embeddings from a list of json files and don't allow duplicate keys.
+
+        Args:
+            file_paths (List[str]): List of paths to the target json files.
+        """
+        self.name_to_id = {}
+        self.clip_ids = []
+        self.embeddings_by_names = {}
+        self.embeddings = {}
+        for file_path in file_paths:
+            ids, clip_ids, embeddings, embeddings_by_names = self.read_embeddings_from_file(file_path)
+            # check colliding keys
+            duplicates = set(self.embeddings.keys()) & set(embeddings.keys())
+            if duplicates:
+                raise ValueError(f" [!] Duplicate embedding names <{duplicates}> in {file_path}")
+            # store values
+            self.name_to_id.update(ids)
+            self.clip_ids.extend(clip_ids)
+            self.embeddings_by_names.update(embeddings_by_names)
+            self.embeddings.update(embeddings)

     def get_embedding_by_clip(self, clip_idx: str) -> List:
         """Get embedding by clip ID.
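Colliding clip keys across files raise an error, which is why each dataset should get a unique `dataset_name` when its embeddings are computed; a sketch reusing `manager` from above, with placeholder file names:

```python
try:
    manager.load_embeddings_from_list_of_files(["speakers.pth", "speakers_copy.pth"])
except ValueError as err:
    print(err)  # " [!] Duplicate embedding names <...> in speakers_copy.pth"
```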
@@ -73,14 +73,14 @@ class SpeakerManager(EmbeddingManager):

     @property
     def num_speakers(self):
-        return len(self.ids)
+        return len(self.name_to_id)

     @property
     def speaker_names(self):
-        return list(self.ids.keys())
+        return list(self.name_to_id.keys())

     def get_speakers(self) -> List:
-        return self.ids
+        return self.name_to_id

     @staticmethod
     def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager":
@@ -182,10 +182,10 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
             speaker_manager.load_embeddings_from_file(c.d_vector_file)
             speaker_manager.load_embeddings_from_file(speakers_file)
         elif not c.use_d_vector_file:  # restor speaker manager with speaker ID file.
-            speaker_ids_from_data = speaker_manager.ids
+            speaker_ids_from_data = speaker_manager.name_to_id
             speaker_manager.load_ids_from_file(speakers_file)
             assert all(
-                speaker in speaker_manager.ids for speaker in speaker_ids_from_data
+                speaker in speaker_manager.name_to_id for speaker in speaker_ids_from_data
             ), " [!] You cannot introduce new speakers to a pre-trained model."
         elif c.use_d_vector_file and c.d_vector_file:
             # new speaker manager with external speaker embeddings.
@@ -199,7 +199,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     if speaker_manager.num_speakers > 0:
         print(
             " > Speaker manager is loaded with {} speakers: {}".format(
-                speaker_manager.num_speakers, ", ".join(speaker_manager.ids)
+                speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id)
             )
         )
@@ -212,7 +212,7 @@ class Synthesizer(object):
         # handle multi-speaker
         speaker_embedding = None
         speaker_id = None
-        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
+        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
             if speaker_name and isinstance(speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the average speaker embedding from the saved d_vectors.
@@ -222,7 +222,7 @@ class Synthesizer(object):
                     speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
                 else:
                     # get speaker idx from the speaker name
-                    speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
+                    speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]

             elif not speaker_name and not speaker_wav:
                 raise ValueError(
@@ -244,7 +244,7 @@ class Synthesizer(object):
             hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
         ):
             if language_name and isinstance(language_name, str):
-                language_id = self.tts_model.language_manager.ids[language_name]
+                language_id = self.tts_model.language_manager.name_to_id[language_name]

             elif not language_name:
                 raise ValueError(
@@ -316,7 +316,7 @@ class Synthesizer(object):
         # get the speaker embedding or speaker id for the reference wav file
         reference_speaker_embedding = None
         reference_speaker_id = None
-        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
+        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
             if reference_speaker_name and isinstance(reference_speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the speaker embedding from the saved d_vectors.
@@ -328,12 +328,11 @@ class Synthesizer(object):
                     ]  # [1 x embedding_dim]
                 else:
                     # get speaker idx from the speaker name
-                    reference_speaker_id = self.tts_model.speaker_manager.ids[reference_speaker_name]
+                    reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name]
             else:
                 reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
                     reference_wav
                 )

         outputs = transfer_voice(
             model=self.tts_model,
             CONFIG=self.tts_config,
@@ -53,7 +53,7 @@ We tried to collect common issues and questions we receive about 🐸TTS. It is
     "mixed_precision": false,
     "output_path": "recipes/ljspeech/glow_tts/",
     "test_sentences": ["Test this sentence.", "This test sentence.", "Sentence this test."],
-    "datasets":[{"name": "ljspeech", "meta_file_train":"metadata.csv", "path": "recipes/ljspeech/LJSpeech-1.1/"}]
+    "datasets":[{"formatter": "ljspeech", "meta_file_train":"metadata.csv", "path": "recipes/ljspeech/LJSpeech-1.1/"}]
 }
 ```
@@ -88,7 +88,7 @@ from TTS.tts.datasets import load_tts_samples

 # dataset config for one of the pre-defined datasets
 dataset_config = BaseDatasetConfig(
-    name="vctk", meta_file_train="", language="en-us", path="dataset-path")
+    formatter="vctk", meta_file_train="", language="en-us", path="dataset-path")
 )

 # load training samples
@@ -84,7 +84,7 @@ We still support running training from CLI like in the old days. The same training
     "print_eval": true,
     "mixed_precision": false,
     "output_path": "recipes/ljspeech/glow_tts/",
-    "datasets":[{"name": "ljspeech", "meta_file_train":"metadata.csv", "path": "recipes/ljspeech/LJSpeech-1.1/"}]
+    "datasets":[{"formatter": "ljspeech", "meta_file_train":"metadata.csv", "path": "recipes/ljspeech/LJSpeech-1.1/"}]
 }
 ```
@@ -120,6 +120,3 @@ $ tts-server -h # see the help
 $ tts-server --list_models  # list the available models.
 ```
 
-
-
-
@@ -15,7 +15,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 data_path = "/srv/data/"

 # Using LJSpeech like dataset processing for the blizzard dataset
-dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=data_path)
+dataset_config = BaseDatasetConfig(formatter="ljspeech", meta_file_train="metadata.csv", path=data_path)

 audio_config = BaseAudioConfig(
     sample_rate=24000,
@@ -16,7 +16,7 @@ data_path = "/srv/data/blizzard2013/segmented"

 # Using LJSpeech like dataset processing for the blizzard dataset
 dataset_config = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
     meta_file_train="metadata.csv",
     path=data_path,
 )
@@ -1,7 +1,7 @@
 {
     "datasets": [
         {
-            "name": "kokoro",
+            "formatter": "kokoro",
             "path": "DEFINE THIS",
             "meta_file_train": "metadata.csv",
             "meta_file_val": null
@@ -119,7 +119,7 @@
         "phonemes": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
     },
     "use_speaker_embedding": false,
     "use_gst": false,
     "use_external_speaker_embedding_file": false,
     "external_speaker_embedding_file": "../../speakers-vctk-en.json"
 }
@@ -13,7 +13,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )
 config = AlignTTSConfig(
     batch_size=32,
@@ -14,7 +14,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
     meta_file_train="metadata.csv",
     # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
     path=os.path.join(output_path, "../LJSpeech-1.1/"),
@@ -14,7 +14,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
     meta_file_train="metadata.csv",
     # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
     path=os.path.join(output_path, "../LJSpeech-1.1/"),
@@ -21,7 +21,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 # Set LJSpeech as our target dataset and define its path.
 # You can also use a simple Dict to define the dataset and pass it to your custom formatter.
 dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )

 # INITIALIZE THE TRAINING CONFIGURATION
@@ -11,7 +11,7 @@ from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )

 audio_config = BaseAudioConfig(
@@ -16,7 +16,7 @@ data_path = "/srv/data/"

 # Using LJSpeech like dataset processing for the blizzard dataset
 dataset_config = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
     meta_file_train="metadata.csv",
     path=data_path,
 )
@@ -16,7 +16,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )

 audio_config = BaseAudioConfig(
@@ -16,7 +16,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )

 audio_config = BaseAudioConfig(
@@ -11,7 +11,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+    formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )
 audio_config = VitsAudioConfig(
     sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
@@ -17,7 +17,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 mailabs_path = "/home/julian/workspace/mailabs/**"
 dataset_paths = glob(mailabs_path)
 dataset_config = [
-    BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
+    BaseDatasetConfig(formatter="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
     for path in dataset_paths
 ]
@@ -14,7 +14,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
+    formatter="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
 )

 # download dataset if not already present
@@ -22,7 +22,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 # Set LJSpeech as our target dataset and define its path.
 # You can also use a simple Dict to define the dataset and pass it to your custom formatter.
 dataset_config = BaseDatasetConfig(
-    name="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
+    formatter="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
 )

 # download dataset if not already present
@@ -12,7 +12,7 @@ from TTS.utils.downloaders import download_thorsten_de

 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
-    name="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
+    formatter="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
 )

 # download dataset if not already present
@@ -16,7 +16,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(
-    name="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
+    formatter="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
 )

 # download dataset if not already present
@@ -12,7 +12,7 @@ from TTS.utils.downloaders import download_thorsten_de

 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
-    name="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
+    formatter="thorsten", meta_file_train="metadata.csv", path=os.path.join(output_path, "../thorsten-de/")
 )

 # download dataset if not already present
@@ -11,7 +11,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -11,7 +11,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -22,7 +22,7 @@ if not os.path.exists(dataset_path):
     download_vctk(dataset_path)

 # define dataset config
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=dataset_path)
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=dataset_path)

 # define audio config
 # ❗ resample the dataset externally using `TTS/bin/resample.py` and set `resample=False` for faster training
@@ -31,7 +31,7 @@ config = SpeakerEncoderConfig()

 #### DATASET CONFIG ####
 # The formatter need to return the key "speaker_name" for the speaker encoder and the "emotion_name" for the emotion encoder
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", language="en-us", path=VCTK_PATH)
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", language="en-us", path=VCTK_PATH)

 # add the dataset to the config
 config.datasets = [dataset_config]
@@ -11,7 +11,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -12,7 +12,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -12,7 +12,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -12,7 +12,7 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_config = BaseDatasetConfig(formatter="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

audio_config = BaseAudioConfig(
    sample_rate=22050,
@@ -12,7 +12,7 @@ from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
-    name="vctk", meta_file_train="", language="en-us", path=os.path.join(output_path, "../VCTK/")
+    formatter="vctk", meta_file_train="", language="en-us", path=os.path.join(output_path, "../VCTK/")
)

@@ -42,7 +42,7 @@ def run_cli(command):


def get_test_data_config():
-    return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
+    return BaseDatasetConfig(formatter="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")


def assertHasAttr(test_obj, obj, intendedAttr):
@@ -0,0 +1,92 @@
+import os
+import unittest
+
+import numpy as np
+import torch
+
+from tests import get_tests_input_path
+from TTS.config import load_config
+from TTS.encoder.utils.generic_utils import setup_encoder_model
+from TTS.encoder.utils.io import save_checkpoint
+from TTS.tts.utils.managers import EmbeddingManager
+from TTS.utils.audio import AudioProcessor
+
+encoder_config_path = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
+encoder_model_path = os.path.join(get_tests_input_path(), "checkpoint_0.pth")
+sample_wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0001.wav")
+sample_wav_path2 = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0002.wav")
+embedding_file_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.json")
+embeddings_file_path2 = os.path.join(get_tests_input_path(), "../data/dummy_speakers2.json")
+embeddings_file_pth_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.pth")
+
+
+class EmbeddingManagerTest(unittest.TestCase):
+    """Test EmbeddingManager for loading embedding files and computing embeddings from waveforms"""
+
+    @staticmethod
+    def test_speaker_embedding():
+        # load config
+        config = load_config(encoder_config_path)
+        config.audio.resample = True
+
+        # create a dummy speaker encoder
+        model = setup_encoder_model(config)
+        save_checkpoint(model, None, None, get_tests_input_path(), 0)
+
+        # load audio processor and speaker encoder
+        manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
+
+        # load a sample audio and compute embedding
+        ap = AudioProcessor(**config.audio)
+        waveform = ap.load_wav(sample_wav_path)
+        mel = ap.melspectrogram(waveform)
+        embedding = manager.compute_embeddings(mel)
+        assert embedding.shape[1] == 256
+
+        # compute embedding directly from an input file
+        embedding = manager.compute_embedding_from_clip(sample_wav_path)
+        embedding2 = manager.compute_embedding_from_clip(sample_wav_path)
+        embedding = torch.FloatTensor(embedding)
+        embedding2 = torch.FloatTensor(embedding2)
+        assert embedding.shape[0] == 256
+        assert (embedding - embedding2).sum() == 0.0
+
+        # compute embedding from a list of wav files.
+        embedding3 = manager.compute_embedding_from_clip([sample_wav_path, sample_wav_path2])
+        embedding3 = torch.FloatTensor(embedding3)
+        assert embedding3.shape[0] == 256
+        assert (embedding - embedding3).sum() != 0.0
+
+        # remove dummy model
+        os.remove(encoder_model_path)
+
+    def test_embedding_file_processing(self):  # pylint: disable=no-self-use
+        manager = EmbeddingManager(embedding_file_path=embeddings_file_pth_path)
+        # test embedding querying
+        embedding = manager.get_embedding_by_clip(manager.clip_ids[0])
+        assert len(embedding) == 256
+        embeddings = manager.get_embeddings_by_name(manager.embedding_names[0])
+        assert len(embeddings[0]) == 256
+        embedding1 = manager.get_mean_embedding(manager.embedding_names[0], num_samples=2, randomize=True)
+        assert len(embedding1) == 256
+        embedding2 = manager.get_mean_embedding(manager.embedding_names[0], num_samples=2, randomize=False)
+        assert len(embedding2) == 256
+        assert np.sum(np.array(embedding1) - np.array(embedding2)) != 0
+
+    def test_embedding_file_loading(self):
+        # test loading a json file
+        manager = EmbeddingManager(embedding_file_path=embedding_file_path)
+        self.assertEqual(manager.num_embeddings, 384)
+        self.assertEqual(manager.embedding_dim, 256)
+        # test loading a pth file
+        manager = EmbeddingManager(embedding_file_path=embeddings_file_pth_path)
+        self.assertEqual(manager.num_embeddings, 384)
+        self.assertEqual(manager.embedding_dim, 256)
+        # test loading pth files with duplicate embedding keys
+        with self.assertRaises(Exception) as context:
+            manager = EmbeddingManager(embedding_file_path=[embeddings_file_pth_path, embeddings_file_pth_path])
+        self.assertTrue("Duplicate embedding names" in str(context.exception))
+        # test loading embedding files with different embedding keys
+        manager = EmbeddingManager(embedding_file_path=[embeddings_file_pth_path, embeddings_file_path2])
+        self.assertEqual(manager.embedding_dim, 256)
+        self.assertEqual(manager.num_embeddings, 384 * 2)
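The last two assertions above capture the new multi-file behaviour: duplicate keys abort loading, disjoint keys merge. A minimal sketch of the same call pattern outside the test harness, with hypothetical file names:

from TTS.tts.utils.managers import EmbeddingManager

# keys must be unique across files; overlapping keys raise "Duplicate embedding names"
manager = EmbeddingManager(embedding_file_path=["speakers_vctk.pth", "speakers_libri.pth"])
print(manager.num_embeddings, manager.embedding_dim)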
@@ -12,7 +12,7 @@ torch.manual_seed(1)
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")

dataset_config_en = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
@@ -21,7 +21,7 @@ dataset_config_en = BaseDatasetConfig(
"""
dataset_config_pt = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
@@ -11,7 +11,7 @@ def run_test_train():
    command = (
        f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
        f"--coqpit.output_path {output_path} "
-        "--coqpit.datasets.0.name ljspeech_test "
+        "--coqpit.datasets.0.formatter ljspeech_test "
        "--coqpit.datasets.0.meta_file_train metadata.csv "
        "--coqpit.datasets.0.meta_file_val metadata.csv "
        "--coqpit.datasets.0.path tests/data/ljspeech "
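The `--coqpit.datasets.0.formatter` flag is coqpit's dotted CLI override: each flag rewrites one field of the loaded config, which is why the rename has to be mirrored in every such command. A rough Python equivalent of the override above, with a hypothetical config file path:

from TTS.config import load_config

config = load_config("test_model_config.json")  # hypothetical path
config.datasets[0].formatter = "ljspeech_test"  # what the CLI flag sets
config.datasets[0].meta_file_train = "metadata.csv"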
@@ -59,7 +59,7 @@ class SpeakerManagerTest(unittest.TestCase):
        # remove dummy model
        os.remove(encoder_model_path)

-    def test_speakers_file_processing(self):
+    def test_dvector_file_processing(self):
        manager = SpeakerManager(d_vectors_file_path=d_vectors_file_path)
        self.assertEqual(manager.num_speakers, 1)
        self.assertEqual(manager.embedding_dim, 256)
File diff suppressed because it is too large
@ -24,7 +24,7 @@ c.data_path = os.path.join(get_tests_data_path(), "ljspeech/")
|
||||||
ok_ljspeech = os.path.exists(c.data_path)
|
ok_ljspeech = os.path.exists(c.data_path)
|
||||||
|
|
||||||
dataset_config = BaseDatasetConfig(
|
dataset_config = BaseDatasetConfig(
|
||||||
name="ljspeech_test", # ljspeech_test to multi-speaker
|
formatter="ljspeech_test", # ljspeech_test to multi-speaker
|
||||||
meta_file_train="metadata.csv",
|
meta_file_train="metadata.csv",
|
||||||
meta_file_val=None,
|
meta_file_val=None,
|
||||||
path=c.data_path,
|
path=c.data_path,
|
||||||
|
|
|
@@ -15,7 +15,7 @@ from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
torch.manual_seed(0)

dataset_config_en = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
@@ -23,7 +23,7 @@ dataset_config_en = BaseDatasetConfig(
)

dataset_config_pt = BaseDatasetConfig(
-    name="ljspeech",
+    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
@@ -148,7 +148,7 @@
    "datasets":  // List of datasets. They are all merged and get different speaker_ids.
        [
            {
-                "name": "ljspeech",
+                "formatter": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv",
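In the JSON docs the same rename shows up as a key change. A minimal sketch of how it surfaces after loading, assuming the full config around the snippet above is saved to a hypothetical `config.json`:

from TTS.config import load_config

config = load_config("config.json")  # hypothetical path to the JSON above
assert config.datasets[0].formatter == "ljspeech"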
@@ -140,12 +140,10 @@
    "datasets":  // List of datasets. They are all merged and get different speaker_ids.
        [
            {
-                "name": "ljspeech",
+                "formatter": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv"
            }
        ]
}
@@ -145,7 +145,7 @@
    "datasets":  // List of datasets. They are all merged and get different speaker_ids.
        [
            {
-                "name": "ljspeech",
+                "formatter": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv",
@@ -166,7 +166,7 @@
    "datasets":  // List of datasets. They are all merged and get different speaker_ids.
        [
            {
-                "name": "ljspeech",
+                "formatter": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv"
@@ -174,4 +174,3 @@
        ]
-
}
@@ -166,7 +166,7 @@
    "datasets":  // List of datasets. They are all merged and get different speaker_ids.
        [
            {
-                "name": "ljspeech",
+                "formatter": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv"
@@ -174,4 +174,3 @@
        ]
-
}
@@ -166,7 +166,7 @@
    "datasets":  // List of datasets. They are all merged and get different speaker_ids.
        [
            {
-                "name": "ljspeech",
+                "formatter": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv"
@@ -174,4 +174,3 @@
        ]
-
}
@@ -39,7 +39,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -56,7 +56,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -56,7 +56,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -43,7 +43,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -40,7 +40,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -39,7 +39,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -38,7 +38,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -44,7 +44,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -42,7 +42,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -39,7 +39,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -39,7 +39,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -43,7 +43,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@ -14,7 +14,7 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
||||||
|
|
||||||
dataset_config_en = BaseDatasetConfig(
|
dataset_config_en = BaseDatasetConfig(
|
||||||
name="ljspeech",
|
formatter="ljspeech",
|
||||||
meta_file_train="metadata.csv",
|
meta_file_train="metadata.csv",
|
||||||
meta_file_val="metadata.csv",
|
meta_file_val="metadata.csv",
|
||||||
path="tests/data/ljspeech",
|
path="tests/data/ljspeech",
|
||||||
|
@ -22,7 +22,7 @@ dataset_config_en = BaseDatasetConfig(
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_config_pt = BaseDatasetConfig(
|
dataset_config_pt = BaseDatasetConfig(
|
||||||
name="ljspeech",
|
formatter="ljspeech",
|
||||||
meta_file_train="metadata.csv",
|
meta_file_train="metadata.csv",
|
||||||
meta_file_val="metadata.csv",
|
meta_file_val="metadata.csv",
|
||||||
path="tests/data/ljspeech",
|
path="tests/data/ljspeech",
|
||||||
|
|
|
@ -14,7 +14,7 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
||||||
|
|
||||||
dataset_config_en = BaseDatasetConfig(
|
dataset_config_en = BaseDatasetConfig(
|
||||||
name="ljspeech_test",
|
formatter="ljspeech_test",
|
||||||
meta_file_train="metadata.csv",
|
meta_file_train="metadata.csv",
|
||||||
meta_file_val="metadata.csv",
|
meta_file_val="metadata.csv",
|
||||||
path="tests/data/ljspeech",
|
path="tests/data/ljspeech",
|
||||||
|
@ -22,7 +22,7 @@ dataset_config_en = BaseDatasetConfig(
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_config_pt = BaseDatasetConfig(
|
dataset_config_pt = BaseDatasetConfig(
|
||||||
name="ljspeech_test",
|
formatter="ljspeech_test",
|
||||||
meta_file_train="metadata.csv",
|
meta_file_train="metadata.csv",
|
||||||
meta_file_val="metadata.csv",
|
meta_file_val="metadata.csv",
|
||||||
path="tests/data/ljspeech",
|
path="tests/data/ljspeech",
|
||||||
|
|
|
@@ -47,7 +47,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -38,7 +38,7 @@ config.save_json(config_path)
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
-    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
@@ -38,7 +38,7 @@ def test_run_all_models():
        language_manager = LanguageManager(language_ids_file_path=language_files[0])
        language_id = language_manager.language_names[0]

-        speaker_id = list(speaker_manager.ids.keys())[0]
+        speaker_id = list(speaker_manager.name_to_id.keys())[0]
        run_cli(
            f"tts --model_name {model_name} "
            f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
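The final hunk tracks a manager API rename: speaker lookups now go through `name_to_id` rather than `ids`. A hedged sketch of the new access pattern, with a hypothetical d-vector file:

from TTS.tts.utils.speakers import SpeakerManager

manager = SpeakerManager(d_vectors_file_path="speakers.pth")  # hypothetical file
first_speaker = list(manager.name_to_id.keys())[0]  # replaces the old `manager.ids`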