mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'pr/Edresson/731-rebased' into dev
commit d37cfe474a

@@ -128,6 +128,8 @@ core
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
recipes/vctk/VCTK/*
recipes/**/*.npy
recipes/**/*.json
VCTK-Corpus-removed-silence/*

# ignore training logs
@@ -161,4 +163,5 @@ speakers.json
internal/*
*_pitch.npy
*_phoneme.npy
wandb
wandb
depot/*

@@ -1,5 +1,17 @@
{
    "tts_models": {
        "multilingual":{
            "multi-dataset":{
                "your_tts":{
                    "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
                    "default_vocoder": null,
                    "commit": "e9a1953e",
                    "license": "CC BY-NC-ND 4.0",
                    "contact": "egolge@coqui.ai"
                }
            }
        },
        "en": {
            "ek1": {
                "tacotron2": {

@@ -12,7 +12,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters

@@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
        enable_eos_bos=c.enable_eos_bos_chars,
        use_noise_augment=False,
        verbose=verbose,
        speaker_id_mapping=speaker_manager.speaker_ids,
        d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
        speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
        d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
    )

    if c.use_phonemes and c.compute_input_seq_cache:
@@ -234,8 +234,13 @@ def main(args):  # pylint: disable=redefined-outer-name
    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # parse speakers
    speaker_manager = get_speaker_manager(c, args, meta_data_train)
    # init speaker manager
    if c.use_speaker_embedding:
        speaker_manager = SpeakerManager(data_items=meta_data)
    elif c.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
    else:
        speaker_manager = None

    # setup model
    model = setup_model(c)

@@ -0,0 +1,62 @@
"""Find all the unique phonemes in a dataset."""
import argparse
import multiprocessing
from argparse import RawTextHelpFormatter

from tqdm.contrib.concurrent import process_map

from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text import text2phone


def compute_phonemes(item):
    try:
        text = item[0]
        language = item[-1]
        ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
    except:
        return []
    return list(set(ph))


def main():
    # pylint: disable=W0601
    global c
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Find all the unique characters or phonemes in a dataset.\n\n"""
        """
    Example runs:

    python TTS/bin/find_unique_phonemes.py --config_path config.json
    """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
    args = parser.parse_args()

    c = load_config(args.config_path)

    # load all datasets
    train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
    items = train_items + eval_items
    print("Num items:", len(items))

    phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
    phones = []
    for ph in phonemes:
        phones.extend(ph)
    phones = set(phones)
    lower_phones = filter(lambda c: c.islower(), phones)
    phones_force_lower = [c.lower() for c in phones]
    phones_force_lower = set(phones_force_lower)

    print(f" > Number of unique phonemes: {len(phones)}")
    print(f" > Unique phonemes: {''.join(sorted(phones))}")
    print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
    print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")


if __name__ == "__main__":
    main()

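The script above collects per-item phoneme sets in parallel and merges them into one set. A minimal sketch of the same aggregation with a stand-in phonemizer, so it runs without a TTS install (the stand-in function and sample items are hypothetical, for illustration only):

import multiprocessing

from tqdm.contrib.concurrent import process_map


def fake_phonemize(item):
    # stand-in for text2phone(): treat each character as a "phoneme"
    text = item[0]
    return list(set(text))


if __name__ == "__main__":
    items = [["hello world", "wavs/a.wav", "spk0", "en"], ["merhaba", "wavs/b.wav", "spk1", "tr"]]
    per_item = process_map(fake_phonemize, items, max_workers=multiprocessing.cpu_count(), chunksize=1)
    phones = {p for ph in per_item for p in ph}
    print(f" > Number of unique symbols: {len(phones)}")
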
@@ -0,0 +1,89 @@
import argparse
import glob
import multiprocessing
import os
import pathlib

from tqdm.contrib.concurrent import process_map

from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave


def remove_silence(filepath):
    output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
    # ignore if the file exists
    if os.path.exists(output_path) and not args.force:
        return

    # create all directory structure
    pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # load wave
    audio, sample_rate = read_wave(filepath)

    # get speech segments
    segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness)

    segments = list(segments)
    num_segments = len(segments)
    flag = False
    # create the output wave
    if num_segments != 0:
        for i, segment in reversed(list(enumerate(segments))):
            if i >= 1:
                if not flag:
                    concat_segment = segment
                    flag = True
                else:
                    concat_segment = segment + concat_segment
            else:
                if flag:
                    segment = segment + concat_segment
                # print("Saving: ", output_path)
                write_wave(output_path, segment, sample_rate)
                return
    else:
        print("> Just copying the file to:", output_path)
        # if silence removal fails, just write the original file
        write_wave(output_path, audio, sample_rate)
        return


def preprocess_audios():
    files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
    print("> Number of files: ", len(files))
    if not args.force:
        print("> Ignoring files that already exist in the output directory.")

    if files:
        # create threads
        num_threads = multiprocessing.cpu_count()
        process_map(remove_silence, files, max_workers=num_threads, chunksize=15)
    else:
        print("> No files found!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2"
    )
    parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
    parser.add_argument(
        "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output dataset dir"
    )
    parser.add_argument("-f", "--force", default=False, action="store_true", help="Force overwriting existing files")
    parser.add_argument(
        "-g",
        "--glob",
        type=str,
        default="**/*.wav",
        help="Glob pattern for accessing wavs under input_dir, e.g. wav48/*/*.wav",
    )
    parser.add_argument(
        "-a",
        "--aggressiveness",
        type=int,
        default=2,
        help="VAD aggressiveness mode, an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
    )
    args = parser.parse_args()
    preprocess_audios()

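The reversed-enumerate loop above just stitches the detected speech segments back together in their original order before writing one output wave. A tiny sketch of the equivalent result on plain byte strings (toy data, illustration only):

segments = [b"seg-0 ", b"seg-1 ", b"seg-2"]

# what the loop builds step by step, walking the list from the end
concat = segments[-1]
for seg in reversed(segments[:-1]):
    concat = seg + concat

assert concat == b"".join(segments)  # same as a straight concatenation
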
@@ -152,12 +152,19 @@ If you don't specify any models, then it uses LJSpeech based English model.

    # args for multi-speaker synthesis
    parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
    parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
    parser.add_argument(
        "--speaker_idx",
        type=str,
        help="Target speaker ID for a multi-speaker TTS model.",
        default=None,
    )
    parser.add_argument(
        "--language_idx",
        type=str,
        help="Target language ID for a multi-lingual TTS model.",
        default=None,
    )
    parser.add_argument(
        "--speaker_wav",
        nargs="+",
@@ -173,6 +180,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
        const=True,
        default=False,
    )
    parser.add_argument(
        "--list_language_idxs",
        help="List available language ids for the defined multi-lingual model.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    # aux args
    parser.add_argument(
        "--save_spectogram",
@@ -184,7 +199,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
    args = parser.parse_args()

    # print the description if either text or list_models is not set
    if args.text is None and not args.list_models and not args.list_speaker_idxs:
    if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
        parser.parse_args(["-h"])

    # load model manager
@@ -194,6 +209,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
    model_path = None
    config_path = None
    speakers_file_path = None
    language_ids_file_path = None
    vocoder_path = None
    vocoder_config_path = None
    encoder_path = None
@@ -217,6 +233,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
        model_path = args.model_path
        config_path = args.config_path
        speakers_file_path = args.speakers_file_path
        language_ids_file_path = args.language_ids_file_path

    if args.vocoder_path is not None:
        vocoder_path = args.vocoder_path
@@ -231,6 +248,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
        model_path,
        config_path,
        speakers_file_path,
        language_ids_file_path,
        vocoder_path,
        vocoder_config_path,
        encoder_path,
@@ -246,6 +264,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
        print(synthesizer.tts_model.speaker_manager.speaker_ids)
        return

    # query language ids of a multi-lingual model.
    if args.list_language_idxs:
        print(
            " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
        )
        print(synthesizer.tts_model.language_manager.language_id_mapping)
        return

    # check the arguments against a multi-speaker model.
    if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
        print(
@@ -258,7 +284,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
    print(" > Text: {}".format(args.text))

    # kick it
    wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
    wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)

    # save the results
    print(" > Saving output to {}".format(args.out_path))

@@ -11,7 +11,7 @@ from torch.utils.data import DataLoader

from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
from TTS.speaker_encoder.utils.training import init_training
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
@@ -151,7 +151,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    global meta_data_eval

    ap = AudioProcessor(**c.audio)
    model = setup_model(c)
    model = setup_speaker_encoder_model(c)

    optimizer = RAdam(model.parameters(), lr=c.lr)

@@ -1,9 +1,10 @@
import os

from TTS.config import load_config, register_config
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

@@ -45,15 +46,32 @@ def main():
    ap = AudioProcessor(**config.audio)

    # init speaker manager
    if config.use_speaker_embedding:
    if check_config_and_model_args(config, "use_speaker_embedding", True):
        speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
    elif config.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
        if hasattr(config, "model_args"):
            config.model_args.num_speakers = speaker_manager.num_speakers
        else:
            config.num_speakers = speaker_manager.num_speakers
    elif check_config_and_model_args(config, "use_d_vector_file", True):
        speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
        if hasattr(config, "model_args"):
            config.model_args.num_speakers = speaker_manager.num_speakers
        else:
            config.num_speakers = speaker_manager.num_speakers
    else:
        speaker_manager = None

    if hasattr(config, "use_language_embedding") and config.use_language_embedding:
        language_manager = LanguageManager(config=config)
        if hasattr(config, "model_args"):
            config.model_args.num_languages = language_manager.num_languages
        else:
            config.num_languages = language_manager.num_languages
    else:
        language_manager = None

    # init the model from config
    model = setup_model(config, speaker_manager)
    model = setup_model(config, speaker_manager, language_manager)

    # init the trainer and 🚀
    trainer = Trainer(

@@ -95,3 +95,38 @@ def load_config(config_path: str) -> None:
    config = config_class()
    config.from_dict(config_dict)
    return config


def check_config_and_model_args(config, arg_name, value):
    """Check whether the given argument, taken from `config.model_args` if it exists or from `config` otherwise,
    equals the given value.

    Return False if the argument does not exist in `config.model_args` or `config`.
    This is to patch up the compatibility between models with and without `model_args`.

    TODO: Remove this in the future with a unified approach.
    """
    if hasattr(config, "model_args"):
        if arg_name in config.model_args:
            return config.model_args[arg_name] == value
    if hasattr(config, arg_name):
        return config[arg_name] == value
    return False


def get_from_config_or_model_args(config, arg_name):
    """Get the given argument from `config.model_args` if it exists, otherwise from `config`."""
    if hasattr(config, "model_args"):
        if arg_name in config.model_args:
            return config.model_args[arg_name]
    return config[arg_name]


def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Get the given argument from `config.model_args` if it exists, otherwise from `config`, falling back to `def_val`."""
    if hasattr(config, "model_args"):
        if arg_name in config.model_args:
            return config.model_args[arg_name]
    if hasattr(config, arg_name):
        return config[arg_name]
    return def_val

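A small usage sketch for the helpers above. `DummyConfig` and the values are hypothetical stand-ins for a Coqpit config, used only to show the lookup order (`model_args` first, then the top-level config, then the default):

from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default


class DummyConfig(dict):
    # dict that also exposes keys as attributes, mimicking Coqpit access patterns
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


cfg = DummyConfig(
    use_speaker_embedding=False,
    model_args=DummyConfig(use_speaker_embedding=True),
)

# `model_args` wins over the top-level value
assert check_config_and_model_args(cfg, "use_speaker_embedding", True)
# missing everywhere -> the default is returned
assert get_from_config_or_model_args_with_default(cfg, "d_vector_dim", 512) == 512
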
@@ -60,6 +60,12 @@ class BaseAudioConfig(Coqpit):
        trim_db (int):
            Silence threshold used for silence trimming. Defaults to 45.

        do_rms_norm (bool, optional):
            enable/disable RMS volume normalization when loading an audio file. Defaults to False.

        db_level (int, optional):
            dB level used for rms normalization. The range is -99 to 0. Defaults to None.

        power (float):
            Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
            artifacts in the synthesized voice. Defaults to 1.5.
@@ -116,6 +122,9 @@ class BaseAudioConfig(Coqpit):
    # silence trimming
    do_trim_silence: bool = True
    trim_db: int = 45
    # rms volume normalization
    do_rms_norm: bool = False
    db_level: float = None
    # griffin-lim params
    power: float = 1.5
    griffin_lim_iters: int = 60
@@ -198,7 +207,8 @@ class BaseDatasetConfig(Coqpit):
    name: str = ""
    path: str = ""
    meta_file_train: str = ""
    ununsed_speakers: List[str] = None
    ignored_speakers: List[str] = None
    language: str = ""
    meta_file_val: str = ""
    meta_file_attn_mask: str = ""

@@ -335,6 +345,8 @@ class BaseTrainingConfig(Coqpit):
    num_loader_workers: int = 0
    num_eval_loader_workers: int = 0
    use_noise_augment: bool = False
    use_language_weighted_sampler: bool = False

    # paths
    output_path: str = None
    # distributed

@@ -100,7 +100,15 @@ if args.vocoder_path is not None:

# load models
synthesizer = Synthesizer(
    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
    tts_checkpoint=model_path,
    tts_config_path=config_path,
    tts_speakers_file=speakers_file_path,
    tts_languages_file=None,
    vocoder_checkpoint=vocoder_path,
    vocoder_config=vocoder_config_path,
    encoder_checkpoint="",
    encoder_config="",
    use_cuda=args.use_cuda,
)

use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
@@ -165,7 +173,7 @@ def tts():

    style_wav = style_wav_uri_to_dict(style_wav)
    print(" > Model input: {}".format(text))
    wavs = synthesizer.tts(text, speaker_idx=speaker_idx, style_wav=style_wav)
    wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
    out = io.BytesIO()
    synthesizer.save_wav(wavs, out)
    return send_file(out, mimetype="audio/wav")

@@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
        feats = torch.stack(feats)
        labels = torch.stack(labels)

        return feats.transpose(1, 2), labels
        return feats, labels

@@ -1,7 +1,9 @@
import numpy as np
import torch
import torchaudio
from torch import nn

from TTS.speaker_encoder.models.resnet import PreEmphasis
from TTS.utils.io import load_fsspec


@@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module):


class LSTMSpeakerEncoder(nn.Module):
    def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
    def __init__(
        self,
        input_dim,
        proj_dim=256,
        lstm_dim=768,
        num_lstm_layers=3,
        use_lstm_with_projection=True,
        use_torch_spec=False,
        audio_config=None,
    ):
        super().__init__()
        self.use_lstm_with_projection = use_lstm_with_projection
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config

        layers = []
        # choose the LSTM layer
        if use_lstm_with_projection:
@@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module):
        else:
            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        if self.use_torch_spec:
            self.torch_spec = torch.nn.Sequential(
                PreEmphasis(audio_config["preemphasis"]),
                # TorchSTFT(
                #     n_fft=audio_config["fft_size"],
                #     hop_length=audio_config["hop_length"],
                #     win_length=audio_config["win_length"],
                #     sample_rate=audio_config["sample_rate"],
                #     window="hamming_window",
                #     mel_fmin=0.0,
                #     mel_fmax=None,
                #     use_htk=True,
                #     do_amp_to_db=False,
                #     n_mels=audio_config["num_mels"],
                #     power=2.0,
                #     use_mel=True,
                #     mel_norm=None,
                # )
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=audio_config["sample_rate"],
                    n_fft=audio_config["fft_size"],
                    win_length=audio_config["win_length"],
                    hop_length=audio_config["hop_length"],
                    window_fn=torch.hamming_window,
                    n_mels=audio_config["num_mels"],
                ),
            )
        else:
            self.torch_spec = None

        self._init_layers()

    def _init_layers(self):
@@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module):
            elif "weight" in name:
                nn.init.xavier_normal_(param)

    def forward(self, x):
        # TODO: implement state passing for lstms
    def forward(self, x, l2_norm=True):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
                to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.use_torch_spec:
                    x.squeeze_(1)
                    x = self.torch_spec(x)
                x = self.instancenorm(x).transpose(1, 2)
        d = self.layers(x)
        if self.use_lstm_with_projection:
            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
        else:
            d = d[:, -1]
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
        return d

    @torch.no_grad()
    def inference(self, x):
        d = self.layers.forward(x)
        if self.use_lstm_with_projection:
            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
        else:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
    def inference(self, x, l2_norm=True):
        d = self.forward(x, l2_norm=l2_norm)
        return d

    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):

@@ -1,10 +1,25 @@
import numpy as np
import torch
import torchaudio
from torch import nn

# from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec


class PreEmphasis(nn.Module):
    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        assert len(x.size()) == 2

        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
@@ -70,12 +85,17 @@ class ResNetSpeakerEncoder(nn.Module):
        num_filters=[32, 64, 128, 256],
        encoder_type="ASP",
        log_input=False,
        use_torch_spec=False,
        audio_config=None,
    ):
        super(ResNetSpeakerEncoder, self).__init__()

        self.encoder_type = encoder_type
        self.input_dim = input_dim
        self.log_input = log_input
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config

        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
@@ -88,6 +108,36 @@ class ResNetSpeakerEncoder(nn.Module):

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        if self.use_torch_spec:
            self.torch_spec = torch.nn.Sequential(
                PreEmphasis(audio_config["preemphasis"]),
                # TorchSTFT(
                #     n_fft=audio_config["fft_size"],
                #     hop_length=audio_config["hop_length"],
                #     win_length=audio_config["win_length"],
                #     sample_rate=audio_config["sample_rate"],
                #     window="hamming_window",
                #     mel_fmin=0.0,
                #     mel_fmax=None,
                #     use_htk=True,
                #     do_amp_to_db=False,
                #     n_mels=audio_config["num_mels"],
                #     power=2.0,
                #     use_mel=True,
                #     mel_norm=None,
                # )
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=audio_config["sample_rate"],
                    n_fft=audio_config["fft_size"],
                    win_length=audio_config["win_length"],
                    hop_length=audio_config["hop_length"],
                    window_fn=torch.hamming_window,
                    n_mels=audio_config["num_mels"],
                ),
            )
        else:
            self.torch_spec = None

        outmap_size = int(self.input_dim / 8)

        self.attention = nn.Sequential(
@@ -140,9 +190,23 @@ class ResNetSpeakerEncoder(nn.Module):
        return out

    def forward(self, x, l2_norm=False):
        x = x.transpose(1, 2)
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
                to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x.squeeze_(1)
                # if using torch_spec, compute the spectrogram here; otherwise use the mel spec computed by the AP
                if self.use_torch_spec:
                    x = self.torch_spec(x)

                if self.log_input:
                    x = (x + 1e-6).log()
                x = self.instancenorm(x).unsqueeze(1)
@@ -175,11 +239,19 @@ class ResNetSpeakerEncoder(nn.Module):
        return x

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
    def inference(self, x, l2_norm=False):
        return self.forward(x, l2_norm)

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """
        Generate embeddings for a batch of utterances
        x: 1xTxD
        """
        # map to the waveform size
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]

        if max_len < num_frames:
@@ -195,11 +267,10 @@ class ResNetSpeakerEncoder(nn.Module):
            frames_batch.append(frames)

        frames_batch = torch.cat(frames_batch, dim=0)
        embeddings = self.forward(frames_batch, l2_norm=True)
        embeddings = self.inference(frames_batch, l2_norm=l2_norm)

        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)

        return embeddings

    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):

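A short sketch of what the `torch_spec` pipeline above produces for a batch of raw waveforms. The audio settings are hypothetical example values, not the ones any particular config ships with:

import torch
import torchaudio

from TTS.speaker_encoder.models.resnet import PreEmphasis  # added by the hunk above

torch_spec = torch.nn.Sequential(
    PreEmphasis(0.97),
    torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_fft=512,
        win_length=400,
        hop_length=160,
        window_fn=torch.hamming_window,
        n_mels=64,
    ),
)

wav = torch.randn(2, 16000)  # (batch, samples), one second of audio
feats = torch_spec(wav)      # (batch, n_mels, frames)
print(feats.shape)           # e.g. torch.Size([2, 64, 101])
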
@@ -170,16 +170,24 @@ def to_camel(text):
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(c):
    if c.model_params["model_name"].lower() == "lstm":
def setup_speaker_encoder_model(config: "Coqpit"):
    if config.model_params["model_name"].lower() == "lstm":
        model = LSTMSpeakerEncoder(
            c.model_params["input_dim"],
            c.model_params["proj_dim"],
            c.model_params["lstm_dim"],
            c.model_params["num_lstm_layers"],
            config.model_params["input_dim"],
            config.model_params["proj_dim"],
            config.model_params["lstm_dim"],
            config.model_params["num_lstm_layers"],
            use_torch_spec=config.model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    elif config.model_params["model_name"].lower() == "resnet":
        model = ResNetSpeakerEncoder(
            input_dim=config.model_params["input_dim"],
            proj_dim=config.model_params["proj_dim"],
            log_input=config.model_params.get("log_input", False),
            use_torch_spec=config.model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    elif c.model_params["model_name"].lower() == "resnet":
        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
    return model

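A hedged usage sketch for the renamed factory above. The config values are illustrative only; `SimpleNamespace` just mimics the attribute access a Coqpit config provides, and no audio config is passed because `use_torch_spec` stays off:

from types import SimpleNamespace

from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model  # name after this change

cfg = SimpleNamespace(
    model_params={
        "model_name": "lstm",
        "input_dim": 80,
        "proj_dim": 256,
        "lstm_dim": 768,
        "num_lstm_layers": 3,
    },
    audio=None,  # only needed when "use_torch_spec" is True
)
model = setup_speaker_encoder_model(cfg)
print(sum(p.numel() for p in model.parameters()), "parameters")
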
@@ -202,7 +202,7 @@ class Trainer:
        os.makedirs(output_path, exist_ok=True)

        # copy training assets to the output folder
        copy_model_files(config, output_path, new_fields=None)
        copy_model_files(config, output_path)

        # init class members
        self.args = args
@@ -439,7 +439,7 @@ class Trainer:
            if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
                print(" > Restoring Scaler...")
                scaler = _restore_list_objs(checkpoint["scaler"], scaler)
        except (KeyError, RuntimeError):
        except (KeyError, RuntimeError, ValueError):
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)

@@ -82,8 +82,14 @@ class VitsConfig(BaseTTSConfig):
        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[str]):
            List of sentences to be used for testing.
        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, language embedding is used. Defaults to `False`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
@@ -117,6 +123,7 @@ class VitsConfig(BaseTTSConfig):
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0
    dur_loss_alpha: float = 1.0
    speaker_encoder_loss_alpha: float = 1.0

    # data loader params
    return_wav: bool = True
@@ -130,13 +137,13 @@ class VitsConfig(BaseTTSConfig):
    add_blank: bool = True

    # testing
    test_sentences: List[str] = field(
    test_sentences: List[List] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
            ["Be a voice, not an echo."],
            ["I'm sorry Dave. I'm afraid I can't do that."],
            ["This cake is great. It's so delicious and moist."],
            ["Prior to November 22, 1963."],
        ]
    )

@@ -146,29 +153,15 @@ class VitsConfig(BaseTTSConfig):
    use_speaker_embedding: bool = False
    speakers_file: str = None
    speaker_embedding_channels: int = 256
    language_ids_file: str = None
    use_language_embedding: bool = False

    # use d-vectors
    use_d_vector_file: bool = False
    d_vector_file: str = False
    d_vector_file: str = None
    d_vector_dim: int = None

    def __post_init__(self):
        # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
        if self.num_speakers > 0:
            self.model_args.num_speakers = self.num_speakers

        # speaker embedding settings
        if self.use_speaker_embedding:
            self.model_args.use_speaker_embedding = True
            if self.speakers_file:
                self.model_args.speakers_file = self.speakers_file
            if self.speaker_embedding_channels:
                self.model_args.speaker_embedding_channels = self.speaker_embedding_channels

        # d-vector settings
        if self.use_d_vector_file:
            self.model_args.use_d_vector_file = True
            if self.d_vector_dim is not None and self.d_vector_dim > 0:
                self.model_args.d_vector_dim = self.d_vector_dim
            if self.d_vector_file:
                self.model_args.d_vector_file = self.d_vector_file
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val

@@ -67,16 +67,22 @@ def load_tts_samples(
        root_path = dataset["path"]
        meta_file_train = dataset["meta_file_train"]
        meta_file_val = dataset["meta_file_val"]
        ignored_speakers = dataset["ignored_speakers"]
        language = dataset["language"]

        # setup the right data processor
        if formatter is None:
            formatter = _get_formatter_by_name(name)
        # load train set
        meta_data_train = formatter(root_path, meta_file_train)
        meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
        meta_data_train = [[*item, language] for item in meta_data_train]

        print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
        # load evaluation split if set
        if eval_split:
            if meta_file_val:
                meta_data_eval = formatter(root_path, meta_file_val)
                meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
                meta_data_eval = [[*item, language] for item in meta_data_eval]
            else:
                meta_data_eval, meta_data_train = split_dataset(meta_data_train)
            meta_data_eval_all += meta_data_eval

@@ -37,6 +37,7 @@ class TTSDataset(Dataset):
        enable_eos_bos: bool = False,
        speaker_id_mapping: Dict = None,
        d_vector_mapping: Dict = None,
        language_id_mapping: Dict = None,
        use_noise_augment: bool = False,
        verbose: bool = False,
    ):
@@ -122,7 +123,9 @@ class TTSDataset(Dataset):
        self.enable_eos_bos = enable_eos_bos
        self.speaker_id_mapping = speaker_id_mapping
        self.d_vector_mapping = d_vector_mapping
        self.language_id_mapping = language_id_mapping
        self.use_noise_augment = use_noise_augment

        self.verbose = verbose
        self.input_seq_computed = False
        self.rescue_item_idx = 1
@@ -197,10 +200,10 @@ class TTSDataset(Dataset):
    def load_data(self, idx):
        item = self.items[idx]

        if len(item) == 4:
            text, wav_file, speaker_name, attn_file = item
        if len(item) == 5:
            text, wav_file, speaker_name, language_name, attn_file = item
        else:
            text, wav_file, speaker_name = item
            text, wav_file, speaker_name, language_name = item
            attn = None
        raw_text = text

@@ -218,7 +221,7 @@ class TTSDataset(Dataset):
                self.phoneme_cache_path,
                self.enable_eos_bos,
                self.cleaners,
                self.phoneme_language,
                language_name if language_name else self.phoneme_language,
                self.custom_symbols,
                self.characters,
                self.add_blank,
@@ -260,6 +263,7 @@ class TTSDataset(Dataset):
            "attn": attn,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
            "language_name": language_name,
            "wav_file_name": os.path.basename(wav_file),
        }
        return sample
@@ -269,6 +273,7 @@ class TTSDataset(Dataset):
        item = args[0]
        func_args = args[1]
        text, wav_file, *_ = item
        func_args[3] = item[3]
        phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
        return phonemes

@@ -335,7 +340,6 @@ class TTSDataset(Dataset):
        else:
            lengths = np.array([len(ins[0]) for ins in self.items])

        # sort items based on the sequence length in ascending order
        idxs = np.argsort(lengths)
        new_items = []
        ignored = []
@@ -345,10 +349,7 @@ class TTSDataset(Dataset):
                ignored.append(idx)
            else:
                new_items.append(self.items[idx])

        # shuffle batch groups
        # create batches with similar length items
        # the larger the `batch_group_size`, the higher the length variety in a batch.
        if self.batch_group_size > 0:
            for i in range(len(new_items) // self.batch_group_size):
                offset = i * self.batch_group_size
@@ -356,14 +357,8 @@ class TTSDataset(Dataset):
                temp_items = new_items[offset:end_offset]
                random.shuffle(temp_items)
                new_items[offset:end_offset] = temp_items

        if len(new_items) == 0:
            raise RuntimeError(" [!] No items left after filtering.")

        # update items to the new sorted items
        self.items = new_items

        # logging
        if self.verbose:
            print(" | > Max length sequence: {}".format(np.max(lengths)))
            print(" | > Min length sequence: {}".format(np.min(lengths)))
@@ -413,9 +408,14 @@ class TTSDataset(Dataset):
            # convert list of dicts to dict of lists
            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

            # get language ids from language names
            if self.language_id_mapping is not None:
                language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
            else:
                language_ids = None
            # get pre-computed d-vectors
            if self.d_vector_mapping is not None:
                wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
                wav_files_names = list(batch["wav_file_name"])
                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
            else:
                d_vectors = None
@@ -466,6 +466,9 @@ class TTSDataset(Dataset):
            if speaker_ids is not None:
                speaker_ids = torch.LongTensor(speaker_ids)

            if language_ids is not None:
                language_ids = torch.LongTensor(language_ids)

            # compute linear spectrogram
            if self.compute_linear_spec:
                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
@@ -528,6 +531,7 @@ class TTSDataset(Dataset):
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
                "language_ids": language_ids,
            }

        raise TypeError(
@@ -542,7 +546,6 @@ class TTSDataset(Dataset):

class PitchExtractor:
    """Pitch Extractor for computing F0 from wav files.

    Args:
        items (List[List]): Dataset samples.
        verbose (bool): Whether to print the progress.

@@ -12,7 +12,7 @@ from tqdm import tqdm
########################


def tweb(root_path, meta_file):
def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalize TWEB dataset.
    https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
    """
@@ -28,7 +28,7 @@ def tweb(root_path, meta_file):
    return items


def mozilla(root_path, meta_file):
def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes Mozilla meta data files to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
@@ -43,7 +43,7 @@ def mozilla(root_path, meta_file):
    return items


def mozilla_de(root_path, meta_file):
def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes Mozilla meta data files to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
@@ -59,7 +59,7 @@ def mozilla_de(root_path, meta_file):
    return items


def mailabs(root_path, meta_files=None):
def mailabs(root_path, meta_files=None, ignored_speakers=None):
    """Normalizes M-AI-Labs meta data files to TTS format

    Args:
@@ -68,25 +68,34 @@ def mailabs(root_path, meta_files=None):
            recursively. Defaults to None
    """
    speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
    if meta_files is None:
    if not meta_files:
        csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
    else:
        csv_files = meta_files

    # meta_files = [f.strip() for f in meta_files.split(",")]
    items = []
    for csv_file in csv_files:
        txt_file = os.path.join(root_path, csv_file)
        if os.path.isfile(csv_file):
            txt_file = csv_file
        else:
            txt_file = os.path.join(root_path, csv_file)

        folder = os.path.dirname(txt_file)
        # determine speaker based on folder structure...
        speaker_name_match = speaker_regex.search(txt_file)
        if speaker_name_match is None:
            continue
        speaker_name = speaker_name_match.group("speaker_name")
        # ignore speakers
        if isinstance(ignored_speakers, list):
            if speaker_name in ignored_speakers:
                continue
        print(" | > {}".format(csv_file))
        with open(txt_file, "r", encoding="utf-8") as ttf:
            for line in ttf:
                cols = line.split("|")
                if meta_files is None:
                if not meta_files:
                    wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
                else:
                    wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
@@ -94,11 +103,12 @@ def mailabs(root_path, meta_files=None):
                    text = cols[1].strip()
                    items.append([text, wav_file, speaker_name])
                else:
                    raise RuntimeError("> File %s does not exist!" % (wav_file))
                    # M-AI-Labs have some missing samples, so just print the warning
                    print("> File %s does not exist!" % (wav_file))
    return items


def ljspeech(root_path, meta_file):
def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the LJSpeech meta data file to TTS format
    https://keithito.com/LJ-Speech-Dataset/"""
    txt_file = os.path.join(root_path, meta_file)
@@ -113,7 +123,7 @@ def ljspeech(root_path, meta_file):
    return items


def ljspeech_test(root_path, meta_file):
def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the LJSpeech meta data file for TTS testing
    https://keithito.com/LJ-Speech-Dataset/"""
    txt_file = os.path.join(root_path, meta_file)
@@ -127,7 +137,7 @@ def ljspeech_test(root_path, meta_file):
    return items


def sam_accenture(root_path, meta_file):
def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the sam-accenture meta data file to TTS format
    https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
    xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)
@@ -144,12 +154,12 @@ def sam_accenture(root_path, meta_file):
    return items


def ruslan(root_path, meta_file):
def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the RUSLAN meta data file to TTS format
    https://ruslan-corpus.github.io/"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    speaker_name = "ruslan"
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
@@ -159,11 +169,11 @@ def ruslan(root_path, meta_file):
    return items


def css10(root_path, meta_file):
def css10(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the CSS10 dataset file to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    speaker_name = "css10"
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split("|")
@@ -173,7 +183,7 @@ def css10(root_path, meta_file):
    return items


def nancy(root_path, meta_file):
def nancy(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the Nancy meta data file to TTS format"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
@@ -187,7 +197,7 @@ def nancy(root_path, meta_file):
    return items


def common_voice(root_path, meta_file):
def common_voice(root_path, meta_file, ignored_speakers=None):
    """Normalize the common voice meta data file to TTS format."""
    txt_file = os.path.join(root_path, meta_file)
    items = []
@@ -198,15 +208,19 @@ def common_voice(root_path, meta_file):
            cols = line.split("\t")
            text = cols[2]
            speaker_name = cols[0]
            # ignore speakers
            if isinstance(ignored_speakers, list):
                if speaker_name in ignored_speakers:
                    continue
            wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
            items.append([text, wav_file, "MCV_" + speaker_name])
    return items


def libri_tts(root_path, meta_files=None):
def libri_tts(root_path, meta_files=None, ignored_speakers=None):
    """https://ai.google/tools/datasets/libri-tts/"""
    items = []
    if meta_files is None:
    if not meta_files:
        meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
    else:
        if isinstance(meta_files, str):
@@ -222,13 +236,17 @@ def libri_tts(root_path, meta_files=None):
                _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
                wav_file = os.path.join(_root_path, file_name + ".wav")
                text = cols[2]
                # ignore speakers
                if isinstance(ignored_speakers, list):
                    if speaker_name in ignored_speakers:
                        continue
                items.append([text, wav_file, "LTTS_" + speaker_name])
    for item in items:
        assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
    return items


def custom_turkish(root_path, meta_file):
def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "turkish-female"
@@ -247,7 +265,7 @@ def custom_turkish(root_path, meta_file):


# ToDo: add the dataset link when the dataset is released publicly
def brspeech(root_path, meta_file):
def brspeech(root_path, meta_file, ignored_speakers=None):
    """BRSpeech 3.0 beta"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
@@ -258,21 +276,25 @@ def brspeech(root_path, meta_file):
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[2]
            speaker_name = cols[3]
            items.append([text, wav_file, speaker_name])
            speaker_id = cols[3]
            # ignore speakers
            if isinstance(ignored_speakers, list):
                if speaker_id in ignored_speakers:
                    continue
            items.append([text, wav_file, speaker_id])
    return items


def vctk(root_path, meta_files=None, wavs_path="wav48"):
def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    test_speakers = meta_files
    items = []
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for meta_file in meta_files:
        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        if isinstance(test_speakers, list):  # if is list ignore this speakers ids
            if speaker_id in test_speakers:
        # ignore speakers
        if isinstance(ignored_speakers, list):
            if speaker_id in ignored_speakers:
                continue
        with open(meta_file, "r", encoding="utf-8") as file_text:
            text = file_text.readlines()[0]
@@ -282,15 +304,16 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
    return items


def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):  # pylint: disable=unused-argument
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    items = []
    txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for text_file in txt_files:
        _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        if isinstance(meta_files, list):  # if is list ignore this speakers ids
            if speaker_id in meta_files:
        # ignore speakers
        if isinstance(ignored_speakers, list):
            if speaker_id in ignored_speakers:
                continue
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([None, wav_file, "VCTK_" + speaker_id])
@@ -298,7 +321,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
    return items


def mls(root_path, meta_files=None):
def mls(root_path, meta_files=None, ignored_speakers=None):
    """http://www.openslr.org/94/"""
    items = []
    with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
@@ -307,19 +330,23 @@ def mls(root_path, meta_files=None):
            text = text[:-1]
            speaker, book, *_ = file.split("_")
            wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav")
            # ignore speakers
            if isinstance(ignored_speakers, list):
                if speaker in ignored_speakers:
                    continue
            items.append([text, wav_file, "MLS_" + speaker])
    return items


# ======================================== VOX CELEB ===========================================
def voxceleb2(root_path, meta_file=None):
def voxceleb2(root_path, meta_file=None, **kwargs):  # pylint: disable=unused-argument
    """
    :param meta_file Used only for consistency with load_tts_samples api
    """
    return _voxcel_x(root_path, meta_file, voxcel_idx="2")


def voxceleb1(root_path, meta_file=None):
def voxceleb1(root_path, meta_file=None, **kwargs):  # pylint: disable=unused-argument
    """
    :param meta_file Used only for consistency with load_tts_samples api
    """
@@ -361,7 +388,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
        return [x.strip().split("|") for x in f.readlines()]


def baker(root_path: str, meta_file: str) -> List[List[str]]:
def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylint: disable=unused-argument
    """Normalizes the Baker meta data file to TTS format

    Args:
@@ -381,7 +408,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
    return items


def kokoro(root_path, meta_file):
def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset"""
    txt_file = os.path.join(root_path, meta_file)
    items = []

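A hedged usage sketch of the new `ignored_speakers` argument threaded through the formatters above. The dataset path and speaker ids are placeholders, and the module path is assumed to be `TTS.tts.datasets.formatters`:

from TTS.tts.datasets.formatters import vctk  # assumed module path

# skip two hypothetical speakers while loading the corpus
items = vctk("/data/VCTK-Corpus", ignored_speakers=["p225", "p226"])
print(len(items), "samples kept")
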
@@ -18,8 +18,13 @@ class DurationPredictor(nn.Module):
        dropout_p (float): Dropout rate used after each conv layer.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
        super().__init__()

        # add language embedding dim in the input
        if language_emb_dim:
            in_channels += language_emb_dim

        # class arguments
        self.in_channels = in_channels
        self.filter_channels = hidden_channels
@@ -36,7 +41,10 @@ class DurationPredictor(nn.Module):
        if cond_channels is not None and cond_channels != 0:
            self.cond = nn.Conv1d(cond_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        if language_emb_dim != 0 and language_emb_dim is not None:
            self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)

    def forward(self, x, x_mask, g=None, lang_emb=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
@@ -45,6 +53,10 @@ class DurationPredictor(nn.Module):
        """
        if g is not None:
            x = x + self.cond(g)

        if lang_emb is not None:
            x = x + self.cond_lang(lang_emb)

        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)

@@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
        self.feat_loss_alpha = c.feat_loss_alpha
        self.dur_loss_alpha = c.dur_loss_alpha
        self.mel_loss_alpha = c.mel_loss_alpha
        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
        self.stft = TorchSTFT(
            c.audio.fft_size,
            c.audio.hop_length,
@@ -585,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
        l = kl / torch.sum(z_mask)
        return l

    @staticmethod
    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
        return l

    def forward(
        self,
        waveform,
@@ -598,6 +604,9 @@ class VitsGeneratorLoss(nn.Module):
        feats_disc_fake,
        feats_disc_real,
        loss_duration,
        use_speaker_encoder_as_loss=False,
        gt_spk_emb=None,
        syn_spk_emb=None,
    ):
        """
        Shapes:
@@ -618,13 +627,20 @@ class VitsGeneratorLoss(nn.Module):
        # compute mel spectrograms from the waveforms
        mel = self.stft(waveform)
        mel_hat = self.stft(waveform_hat)

        # compute losses
        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
        loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
        loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
        loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
        loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

        if use_speaker_encoder_as_loss:
            loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
            loss += loss_se
            return_dict["loss_spk_encoder"] = loss_se

        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl

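The speaker-consistency term added above is just the negative mean cosine similarity between ground-truth and synthesized speaker embeddings. A tiny self-contained check with random embeddings (toy shapes, for illustration):

import torch

gt_spk_emb = torch.nn.functional.normalize(torch.randn(4, 512), dim=1)
syn_spk_emb = torch.nn.functional.normalize(torch.randn(4, 512), dim=1)

loss = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
print(loss)  # close to 0 for random pairs, approaches -1 as the embeddings align
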
@@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
         num_layers: int,
         kernel_size: int,
         dropout_p: float,
+        language_emb_dim: int = None,
     ):
         """Text Encoder for VITS model.

@@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
         self.hidden_channels = hidden_channels

         self.emb = nn.Embedding(n_vocab, hidden_channels)

         nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

+        if language_emb_dim:
+            hidden_channels += language_emb_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,

@@ -72,13 +77,18 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths):
+    def forward(self, x, x_lengths, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]

+        # concat the lang emb in embedding chars
+        if lang_emb is not None:
+            x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

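The text encoder concatenates the language embedding onto every character embedding before the transformer, which is why `hidden_channels` is widened by `language_emb_dim` in `__init__`. A shape walk-through of that concatenation with assumed sizes:

import torch

B, T, H, L = 2, 13, 192, 4          # batch, characters, hidden channels, language_emb_dim
x = torch.randn(B, T, H)            # character embeddings [B, T, H]
lang_emb = torch.randn(B, L, 1)     # language embedding as passed above: [B, L, 1]

x = torch.cat((x, lang_emb.transpose(2, 1).expand(B, T, -1)), dim=-1)
print(x.shape)                      # torch.Size([2, 13, 196]) -> H + L channels per character
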
@@ -178,10 +178,21 @@ class StochasticDurationPredictor(nn.Module):
     """

     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dropout_p: float,
+        num_flows=4,
+        cond_channels=0,
+        language_emb_dim=0,
     ):
         super().__init__()

+        # add language embedding dim in the input
+        if language_emb_dim:
+            in_channels += language_emb_dim
+
         # condition encoder text
         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
         self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)

@@ -205,7 +216,10 @@ class StochasticDurationPredictor(nn.Module):
         if cond_channels != 0 and cond_channels is not None:
             self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)

-    def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
+        if language_emb_dim != 0 and language_emb_dim is not None:
+            self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
+
+    def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
         """
         Shapes:
             - x: :math:`[B, C, T]`

@@ -217,6 +231,10 @@ class StochasticDurationPredictor(nn.Module):
         x = self.pre(x)
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.convs(x, x_mask)
         x = self.proj(x) * x_mask

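At inference the stochastic duration predictor runs in reverse with a noise scale, and its log-durations are turned into per-token frame counts (see the `w = torch.exp(logw) * x_mask * self.length_scale` lines in the Vits hunks further down). A small numeric sketch with assumed shapes; the clamp on the summed lengths mirrors how a minimum of one output frame is usually enforced:

import torch

# assumed outputs of the (stochastic) duration predictor in reverse mode
logw = torch.randn(1, 1, 13)         # log-durations per input token [B, 1, T]
x_mask = torch.ones(1, 1, 13)
length_scale = 1.0                   # >1.0 slows speech down, <1.0 speeds it up

w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)               # integer number of decoder frames per token
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
print(y_lengths)                     # total decoder frames for each batch item
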
@@ -2,7 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
 from TTS.utils.generic_utils import find_module


-def setup_model(config, speaker_manager: "SpeakerManager" = None):
+def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
     print(" > Using model: {}".format(config.model))
     # fetch the right model implementation.
     if "base_model" in config and config["base_model"] is not None:

@@ -31,7 +31,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
         config.model_params.num_chars = num_chars
     if "model_args" in config:
         config.model_args.num_chars = num_chars
-    model = MyModel(config, speaker_manager=speaker_manager)
+    if config.model.lower() in ["vits"]:  # If model supports multiple languages
+        model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
+    else:
+        model = MyModel(config, speaker_manager=speaker_manager)
     return model

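`setup_model` now forwards a `language_manager` only to models that can use it (currently `vits`). A hedged sketch of the dispatch idea with a stand-in `MyModel`, since the real class lookup through `find_module` is not reproduced here:

class MyModel:  # stand-in for the class returned by find_module()
    def __init__(self, config, speaker_manager=None, language_manager=None):
        self.config = config
        self.speaker_manager = speaker_manager
        self.language_manager = language_manager


def setup_model_sketch(config, speaker_manager=None, language_manager=None):
    # multilingual-capable models receive the language manager; everything else keeps the old call
    if config["model"].lower() in ["vits"]:
        return MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
    return MyModel(config, speaker_manager=speaker_manager)


model = setup_model_sketch({"model": "vits"})
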
@ -12,7 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from TTS.model import BaseModel
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
|
|||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
|
||||
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
|
||||
`in_channels` size of the connected layers.
|
||||
|
||||
This implementation yields 3 possible outcomes:
|
||||
|
||||
1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
|
||||
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
|
||||
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
|
||||
`config.d_vector_dim` or 512.
|
||||
|
||||
You can override this function for new models.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
|
@ -97,6 +107,57 @@ class BaseTTS(BaseModel):
|
|||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
|
||||
|
||||
def get_aux_input_from_test_setences(self, sentence_info):
|
||||
if hasattr(self.config, "model_args"):
|
||||
config = self.config.model_args
|
||||
else:
|
||||
config = self.config
|
||||
|
||||
# extract speaker and language info
|
||||
text, speaker_name, style_wav, language_name = None, None, None, None
|
||||
|
||||
if isinstance(sentence_info, list):
|
||||
if len(sentence_info) == 1:
|
||||
text = sentence_info[0]
|
||||
elif len(sentence_info) == 2:
|
||||
text, speaker_name = sentence_info
|
||||
elif len(sentence_info) == 3:
|
||||
text, speaker_name, style_wav = sentence_info
|
||||
elif len(sentence_info) == 4:
|
||||
text, speaker_name, style_wav, language_name = sentence_info
|
||||
else:
|
||||
text = sentence_info
|
||||
|
||||
# get speaker id/d_vector
|
||||
speaker_id, d_vector, language_id = None, None, None
|
||||
if hasattr(self, "speaker_manager"):
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_d_vector()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_speaker_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
|
||||
|
||||
# get language id
|
||||
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.language_id_mapping[language_name]
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"speaker_id": speaker_id,
|
||||
"style_wav": style_wav,
|
||||
"d_vector": d_vector,
|
||||
"language_id": language_id,
|
||||
}
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
"""Generic batch formatting for `TTSDataset`.
|
||||
|
||||
|
@ -122,6 +183,7 @@ class BaseTTS(BaseModel):
|
|||
attn_mask = batch["attns"]
|
||||
waveform = batch["waveform"]
|
||||
pitch = batch["pitch"]
|
||||
language_ids = batch["language_ids"]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
|
@ -169,6 +231,7 @@ class BaseTTS(BaseModel):
|
|||
"item_idx": item_idx,
|
||||
"waveform": waveform,
|
||||
"pitch": pitch,
|
||||
"language_ids": language_ids,
|
||||
}
|
||||
|
||||
def get_data_loader(
|
||||
|
@ -188,8 +251,15 @@ class BaseTTS(BaseModel):
|
|||
|
||||
# setup multi-speaker attributes
|
||||
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
|
||||
if hasattr(config, "model_args"):
|
||||
speaker_id_mapping = (
|
||||
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
|
||||
)
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
|
||||
config.use_d_vector_file = config.model_args.use_d_vector_file
|
||||
else:
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
|
||||
else:
|
||||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
@ -199,7 +269,14 @@ class BaseTTS(BaseModel):
|
|||
if hasattr(self, "make_symbols"):
|
||||
custom_symbols = self.make_symbols(self.config)
|
||||
|
||||
# init dataset
|
||||
if hasattr(self, "language_manager"):
|
||||
language_id_mapping = (
|
||||
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
|
||||
)
|
||||
else:
|
||||
language_id_mapping = None
|
||||
|
||||
# init dataloader
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
|
@ -222,7 +299,8 @@ class BaseTTS(BaseModel):
|
|||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||
d_vector_mapping=d_vector_mapping,
|
||||
language_id_mapping=language_id_mapping,
|
||||
)
|
||||
|
||||
# pre-compute phonemes
|
||||
|
@ -268,7 +346,22 @@ class BaseTTS(BaseModel):
|
|||
# sampler for DDP
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
||||
# init dataloader
|
||||
# Weighted samplers
|
||||
assert not (
|
||||
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
|
||||
), "language_weighted_sampler is not supported with DistributedSampler"
|
||||
assert not (
|
||||
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
|
||||
), "speaker_weighted_sampler is not supported with DistributedSampler"
|
||||
|
||||
if sampler is None:
|
||||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_language_weighted_sampler(dataset.items)
|
||||
elif getattr(config, "use_speaker_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_speaker_weighted_sampler(dataset.items)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
|
@ -340,8 +433,7 @@ class BaseTTS(BaseModel):
|
|||
return test_figures, test_audios
|
||||
|
||||
def on_init_start(self, trainer):
|
||||
"""Save the speaker.json at the beginning of the training. And update the config.json with the
|
||||
speakers.json file path."""
|
||||
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
|
||||
if self.speaker_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "speakers.json")
|
||||
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
||||
|
@ -352,3 +444,13 @@ class BaseTTS(BaseModel):
|
|||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `speakers.json` is saved to {output_path}.")
|
||||
print(" > `speakers_file` is updated in the config.json.")
|
||||
|
||||
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||
self.language_manager.save_language_ids_to_file(output_path)
|
||||
trainer.config.language_ids_file = output_path
|
||||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.language_ids_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||
print(" > `language_ids_file` is updated in the config.json.")
|
||||
|
|
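`on_init_start` above now also writes `language_ids.json` next to `speakers.json` and points `config.language_ids_file` (and `config.model_args.language_ids_file` when present) at it, so a restored run rebuilds the same language-to-index mapping. A sketch of the file it would produce, with assumed language codes:

import json

# assumed contents of the saved language_ids.json: language name -> embedding index
language_id_mapping = {"en": 0, "fr-fr": 1, "pt-br": 2}
with open("language_ids.json", "w", encoding="utf-8") as f:
    json.dump(language_id_mapping, f, indent=4)
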
|
@ -1,13 +1,15 @@
|
|||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
|
@ -15,6 +17,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
|
|||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
|
@ -138,11 +141,50 @@ class VitsArgs(Coqpit):
|
|||
use_d_vector_file (bool):
|
||||
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
|
||||
|
||||
d_vector_file (str):
|
||||
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
||||
|
||||
d_vector_dim (int):
|
||||
Number of d-vector channels. Defaults to 0.
|
||||
|
||||
detach_dp_input (bool):
|
||||
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
|
||||
|
||||
use_language_embedding (bool):
|
||||
Enable/Disable language embedding for multilingual models. Defaults to False.
|
||||
|
||||
embedded_language_dim (int):
|
||||
Number of language embedding channels. Defaults to 4.
|
||||
|
||||
num_languages (int):
|
||||
Number of languages for the language embedding layer. Defaults to 0.
|
||||
|
||||
language_ids_file (str):
|
||||
Path to the language mapping file for the Language Manager. Defaults to None.
|
||||
|
||||
use_speaker_encoder_as_loss (bool):
|
||||
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
|
||||
|
||||
speaker_encoder_config_path (str):
|
||||
Path to the file speaker encoder config file, to use for SCL. Defaults to "".
|
||||
|
||||
speaker_encoder_model_path (str):
|
||||
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
||||
|
||||
freeze_encoder (bool):
|
||||
Freeze the encoder weights during training. Defaults to False.
|
||||
|
||||
freeze_DP (bool):
|
||||
Freeze the duration predictor weights during training. Defaults to False.
|
||||
|
||||
freeze_PE (bool):
|
||||
Freeze the posterior encoder weights during training. Defaults to False.
|
||||
|
||||
freeze_flow_encoder (bool):
|
||||
Freeze the flow encoder weights during training. Defaults to False.
|
||||
|
||||
freeze_waveform_decoder (bool):
|
||||
Freeze the waveform decoder weights during training. Defaults to False.
|
||||
"""
|
||||
|
||||
num_chars: int = 100
|
||||
|
@ -179,11 +221,23 @@ class VitsArgs(Coqpit):
|
|||
use_speaker_embedding: bool = False
|
||||
num_speakers: int = 0
|
||||
speakers_file: str = None
|
||||
d_vector_file: str = None
|
||||
speaker_embedding_channels: int = 256
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_file: str = None
|
||||
d_vector_dim: int = 0
|
||||
detach_dp_input: bool = True
|
||||
use_language_embedding: bool = False
|
||||
embedded_language_dim: int = 4
|
||||
num_languages: int = 0
|
||||
language_ids_file: str = None
|
||||
use_speaker_encoder_as_loss: bool = False
|
||||
speaker_encoder_config_path: str = ""
|
||||
speaker_encoder_model_path: str = ""
|
||||
freeze_encoder: bool = False
|
||||
freeze_DP: bool = False
|
||||
freeze_PE: bool = False
|
||||
freeze_flow_decoder: bool = False
|
||||
freeze_waveform_decoder: bool = False
|
||||
|
||||
|
||||
class Vits(BaseTTS):
|
||||
|
@ -216,13 +270,18 @@ class Vits(BaseTTS):
|
|||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Coqpit,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
language_manager: LanguageManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
self.END2END = True
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
self.language_manager = language_manager
|
||||
if config.__class__.__name__ == "VitsConfig":
|
||||
# loading from VitsConfig
|
||||
if "num_chars" not in config:
|
||||
|
@ -242,6 +301,7 @@ class Vits(BaseTTS):
|
|||
self.args = args
|
||||
|
||||
self.init_multispeaker(config)
|
||||
self.init_multilingual(config)
|
||||
|
||||
self.length_scale = args.length_scale
|
||||
self.noise_scale = args.noise_scale
|
||||
|
@ -260,6 +320,7 @@ class Vits(BaseTTS):
|
|||
args.num_layers_text_encoder,
|
||||
args.kernel_size_text_encoder,
|
||||
args.dropout_p_text_encoder,
|
||||
language_emb_dim=self.embedded_language_dim,
|
||||
)
|
||||
|
||||
self.posterior_encoder = PosteriorEncoder(
|
||||
|
@ -289,10 +350,16 @@ class Vits(BaseTTS):
|
|||
args.dropout_p_duration_predictor,
|
||||
4,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
language_emb_dim=self.embedded_language_dim,
|
||||
)
|
||||
else:
|
||||
self.duration_predictor = DurationPredictor(
|
||||
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
|
||||
args.hidden_channels,
|
||||
256,
|
||||
3,
|
||||
args.dropout_p_duration_predictor,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
language_emb_dim=self.embedded_language_dim,
|
||||
)
|
||||
|
||||
self.waveform_decoder = HifiganGenerator(
|
||||
|
@ -318,54 +385,158 @@ class Vits(BaseTTS):
|
|||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
self.embedded_speaker_dim = 0
|
||||
if hasattr(config, "model_args"):
|
||||
config = config.model_args
|
||||
self.num_speakers = self.args.num_speakers
|
||||
|
||||
self.num_speakers = config.num_speakers
|
||||
if self.speaker_manager:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
|
||||
if config.use_speaker_embedding:
|
||||
self._init_speaker_embedding(config)
|
||||
if self.args.use_speaker_embedding:
|
||||
self._init_speaker_embedding()
|
||||
|
||||
if config.use_d_vector_file:
|
||||
self._init_d_vector(config)
|
||||
if self.args.use_d_vector_file:
|
||||
self._init_d_vector()
|
||||
|
||||
def _init_speaker_embedding(self, config):
|
||||
# TODO: make this a function
|
||||
if self.args.use_speaker_encoder_as_loss:
|
||||
if self.speaker_manager.speaker_encoder is None and (
|
||||
not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
|
||||
):
|
||||
raise RuntimeError(
|
||||
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
|
||||
)
|
||||
|
||||
self.speaker_manager.speaker_encoder.eval()
|
||||
print(" > External Speaker Encoder Loaded !!")
|
||||
|
||||
if (
|
||||
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
|
||||
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
|
||||
):
|
||||
# TODO: change this with torchaudio Resample
|
||||
raise RuntimeError(
|
||||
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
|
||||
self.config.audio["sample_rate"],
|
||||
self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
||||
)
|
||||
)
|
||||
# pylint: disable=W0101,W0105
|
||||
""" self.audio_transform = torchaudio.transforms.Resample(
|
||||
orig_freq=self.audio_config["sample_rate"],
|
||||
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
||||
)
|
||||
else:
|
||||
self.audio_transform = None
|
||||
"""
|
||||
|
||||
def _init_speaker_embedding(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if config.speakers_file is not None:
|
||||
self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
|
||||
|
||||
if self.num_speakers > 0:
|
||||
print(" > initialization of speaker-embedding layers.")
|
||||
self.embedded_speaker_dim = config.speaker_embedding_channels
|
||||
self.embedded_speaker_dim = self.args.speaker_embedding_channels
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
|
||||
def _init_d_vector(self, config):
|
||||
def _init_d_vector(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if hasattr(self, "emb_g"):
|
||||
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
|
||||
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
||||
self.embedded_speaker_dim = config.d_vector_dim
|
||||
self.embedded_speaker_dim = self.args.d_vector_dim
|
||||
|
||||
def init_multilingual(self, config: Coqpit):
|
||||
"""Initialize multilingual modules of a model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
"""
|
||||
if self.args.language_ids_file is not None:
|
||||
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
||||
|
||||
if self.args.use_language_embedding and self.language_manager:
|
||||
self.num_languages = self.language_manager.num_languages
|
||||
self.embedded_language_dim = self.args.embedded_language_dim
|
||||
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
||||
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
||||
else:
|
||||
self.embedded_language_dim = 0
|
||||
self.emb_l = None
|
||||
|
||||
@staticmethod
|
||||
def _set_cond_input(aux_input: Dict):
|
||||
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||
sid, g = None, None
|
||||
sid, g, lid = None, None, None
|
||||
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||
sid = aux_input["speaker_ids"]
|
||||
if sid.ndim == 0:
|
||||
sid = sid.unsqueeze_(0)
|
||||
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
|
||||
g = aux_input["d_vectors"]
|
||||
return sid, g
|
||||
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
|
||||
if g.ndim == 2:
|
||||
g = g.unsqueeze_(0)
|
||||
|
||||
if "language_ids" in aux_input and aux_input["language_ids"] is not None:
|
||||
lid = aux_input["language_ids"]
|
||||
if lid.ndim == 0:
|
||||
lid = lid.unsqueeze_(0)
|
||||
|
||||
return sid, g, lid
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
return {"speaker_id": sid, "style_wav": None, "d_vector": g}
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
|
||||
|
||||
def get_aux_input_from_test_sentences(self, sentence_info):
|
||||
if hasattr(self.config, "model_args"):
|
||||
config = self.config.model_args
|
||||
else:
|
||||
config = self.config
|
||||
|
||||
# extract speaker and language info
|
||||
text, speaker_name, style_wav, language_name = None, None, None, None
|
||||
|
||||
if isinstance(sentence_info, list):
|
||||
if len(sentence_info) == 1:
|
||||
text = sentence_info[0]
|
||||
elif len(sentence_info) == 2:
|
||||
text, speaker_name = sentence_info
|
||||
elif len(sentence_info) == 3:
|
||||
text, speaker_name, style_wav = sentence_info
|
||||
elif len(sentence_info) == 4:
|
||||
text, speaker_name, style_wav, language_name = sentence_info
|
||||
else:
|
||||
text = sentence_info
|
||||
|
||||
# get speaker id/d_vector
|
||||
speaker_id, d_vector, language_id = None, None, None
|
||||
if hasattr(self, "speaker_manager"):
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_d_vector()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False)
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_speaker_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
|
||||
|
||||
# get language id
|
||||
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.language_id_mapping[language_name]
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"speaker_id": speaker_id,
|
||||
"style_wav": style_wav,
|
||||
"d_vector": d_vector,
|
||||
"language_id": language_id,
|
||||
"language_name": language_name,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -373,7 +544,8 @@ class Vits(BaseTTS):
|
|||
x_lengths: torch.tensor,
|
||||
y: torch.tensor,
|
||||
y_lengths: torch.tensor,
|
||||
aux_input={"d_vectors": None, "speaker_ids": None},
|
||||
waveform: torch.tensor,
|
||||
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
|
||||
) -> Dict:
|
||||
"""Forward pass of the model.
|
||||
|
||||
|
@ -382,7 +554,9 @@ class Vits(BaseTTS):
|
|||
x_lengths (torch.tensor): Batch of input character sequence lengths.
|
||||
y (torch.tensor): Batch of input spectrograms.
|
||||
y_lengths (torch.tensor): Batch of input spectrogram lengths.
|
||||
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
|
||||
waveform (torch.tensor): Batch of ground truth waveforms per sample.
|
||||
aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
|
||||
Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.
|
||||
|
||||
Returns:
|
||||
Dict: model outputs keyed by the output name.
|
||||
|
@ -392,17 +566,24 @@ class Vits(BaseTTS):
|
|||
- x_lengths: :math:`[B]`
|
||||
- y: :math:`[B, C, T_spec]`
|
||||
- y_lengths: :math:`[B]`
|
||||
- waveform: :math:`[B, T_wav, 1]`
|
||||
- d_vectors: :math:`[B, C, 1]`
|
||||
- speaker_ids: :math:`[B]`
|
||||
- language_ids: :math:`[B]`
|
||||
"""
|
||||
outputs = {}
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1 and sid is not None:
|
||||
if self.args.use_speaker_embedding and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# language embedding
|
||||
lang_emb = None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
||||
|
||||
# posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||
|
||||
|
@ -428,6 +609,7 @@ class Vits(BaseTTS):
|
|||
x_mask,
|
||||
attn_durations,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
loss_duration = loss_duration / torch.sum(x_mask)
|
||||
else:
|
||||
|
@ -436,6 +618,7 @@ class Vits(BaseTTS):
|
|||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
|
||||
)
|
||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
outputs["loss_duration"] = loss_duration
|
||||
|
@ -447,40 +630,73 @@ class Vits(BaseTTS):
|
|||
# select a random feature segment for the waveform decoder
|
||||
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
|
||||
o = self.waveform_decoder(z_slice, g=g)
|
||||
|
||||
wav_seg = segment(
|
||||
waveform,
|
||||
slice_ids * self.config.audio.hop_length,
|
||||
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||
)
|
||||
|
||||
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
|
||||
# concate generated and GT waveforms
|
||||
wavs_batch = torch.cat((wav_seg, o), dim=0)
|
||||
|
||||
# resample audio to speaker encoder sample_rate
|
||||
# pylint: disable=W0105
|
||||
"""if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)"""
|
||||
|
||||
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||
|
||||
# split generated and GT speaker embeddings
|
||||
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
|
||||
else:
|
||||
gt_spk_emb, syn_spk_emb = None, None
|
||||
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"slice_ids": slice_ids,
|
||||
"z": z,
|
||||
"z_p": z_p,
|
||||
"m_p": m_p,
|
||||
"logs_p": logs_p,
|
||||
"m_q": m_q,
|
||||
"logs_q": logs_q,
|
||||
"waveform_seg": wav_seg,
|
||||
"gt_spk_emb": gt_spk_emb,
|
||||
"syn_spk_emb": syn_spk_emb,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, T_seq]`
|
||||
- d_vectors: :math:`[B, C, 1]`
|
||||
- speaker_ids: :math:`[B]`
|
||||
"""
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||
|
||||
if self.num_speakers > 0 and sid is not None:
|
||||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
|
||||
# language embedding
|
||||
lang_emb = None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
||||
|
||||
if self.args.use_sdp:
|
||||
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
|
||||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
|
||||
)
|
||||
else:
|
||||
logw = self.duration_predictor(x, x_mask, g=g)
|
||||
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
|
||||
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
|
@ -499,12 +715,30 @@ class Vits(BaseTTS):
|
|||
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
|
||||
return outputs
|
||||
|
||||
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
|
||||
"""TODO: create an end-point for voice conversion"""
|
||||
def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
|
||||
"""Forward pass for voice conversion
|
||||
|
||||
TODO: create an end-point for voice conversion
|
||||
|
||||
Args:
|
||||
y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
|
||||
y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
|
||||
speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
|
||||
speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
|
||||
"""
|
||||
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
|
||||
g_src = self.emb_g(sid_src).unsqueeze(-1)
|
||||
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
|
||||
z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
|
||||
|
||||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
|
||||
g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
|
||||
g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
|
||||
elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
|
||||
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
|
||||
g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
|
||||
else:
|
||||
raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
|
||||
|
||||
z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
|
||||
z_p = self.flow(z, y_mask, g=g_src)
|
||||
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
||||
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
|
||||
|
@ -525,6 +759,30 @@ class Vits(BaseTTS):
|
|||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
if self.args.freeze_encoder:
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if hasattr(self, "emb_l"):
|
||||
for param in self.emb_l.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_PE:
|
||||
for param in self.posterior_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_DP:
|
||||
for param in self.duration_predictor.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_flow_decoder:
|
||||
for param in self.flow.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_waveform_decoder:
|
||||
for param in self.waveform_decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if optimizer_idx == 0:
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
@ -532,6 +790,7 @@ class Vits(BaseTTS):
|
|||
linear_input = batch["linear_input"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
language_ids = batch["language_ids"]
|
||||
waveform = batch["waveform"]
|
||||
|
||||
# generator pass
|
||||
|
@ -540,31 +799,26 @@ class Vits(BaseTTS):
|
|||
text_lengths,
|
||||
linear_input.transpose(1, 2),
|
||||
mel_lengths,
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
|
||||
waveform.transpose(1, 2),
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
||||
)
|
||||
|
||||
# cache tensors for the discriminator
|
||||
self.y_disc_cache = None
|
||||
self.wav_seg_disc_cache = None
|
||||
self.y_disc_cache = outputs["model_outputs"]
|
||||
wav_seg = segment(
|
||||
waveform.transpose(1, 2),
|
||||
outputs["slice_ids"] * self.config.audio.hop_length,
|
||||
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||
)
|
||||
self.wav_seg_disc_cache = wav_seg
|
||||
outputs["waveform_seg"] = wav_seg
|
||||
self.wav_seg_disc_cache = outputs["waveform_seg"]
|
||||
|
||||
# compute discriminator scores and features
|
||||
outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
|
||||
outputs["model_outputs"], wav_seg
|
||||
outputs["model_outputs"], outputs["waveform_seg"]
|
||||
)
|
||||
|
||||
# compute losses
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
waveform_hat=outputs["model_outputs"].float(),
|
||||
waveform=wav_seg.float(),
|
||||
waveform=outputs["waveform_seg"].float(),
|
||||
z_p=outputs["z_p"].float(),
|
||||
logs_q=outputs["logs_q"].float(),
|
||||
m_p=outputs["m_p"].float(),
|
||||
|
@ -574,6 +828,9 @@ class Vits(BaseTTS):
|
|||
feats_disc_fake=outputs["feats_disc_fake"],
|
||||
feats_disc_real=outputs["feats_disc_real"],
|
||||
loss_duration=outputs["loss_duration"],
|
||||
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
|
||||
gt_spk_emb=outputs["gt_spk_emb"],
|
||||
syn_spk_emb=outputs["syn_spk_emb"],
|
||||
)
|
||||
|
||||
elif optimizer_idx == 1:
|
||||
|
@ -651,32 +908,28 @@ class Vits(BaseTTS):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = {
|
||||
"speaker_id": None
|
||||
if not self.config.use_speaker_embedding
|
||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
||||
"d_vector": None
|
||||
if not self.config.use_d_vector_file
|
||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
||||
"style_wav": None,
|
||||
}
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, _, _ = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||
for idx, s_info in enumerate(test_sentences):
|
||||
try:
|
||||
aux_inputs = self.get_aux_input_from_test_sentences(s_info)
|
||||
wav, alignment, _, _ = synthesis(
|
||||
self,
|
||||
aux_inputs["text"],
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
language_id=aux_inputs["language_id"],
|
||||
language_name=aux_inputs["language_name"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||
except: # pylint: disable=bare-except
|
||||
print(" !! Error creating Test Sentence -", idx)
|
||||
return test_figures, test_audios
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
|
@ -695,8 +948,12 @@ class Vits(BaseTTS):
|
|||
self.waveform_decoder.parameters(),
|
||||
)
|
||||
# add the speaker embedding layer
|
||||
if hasattr(self, "emb_g"):
|
||||
if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file:
|
||||
gen_parameters = chain(gen_parameters, self.emb_g.parameters())
|
||||
# add the language embedding layer
|
||||
if hasattr(self, "emb_l") and self.args.use_language_embedding:
|
||||
gen_parameters = chain(gen_parameters, self.emb_l.parameters())
|
||||
|
||||
optimizer0 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||
)
|
||||
|
@ -769,6 +1026,10 @@ class Vits(BaseTTS):
|
|||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
"""Load the model checkpoint and setup for training or inference"""
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# compat band-aid for the pre-trained models to not use the encoder baked into the model
|
||||
# TODO: consider baking the speaker encoder into the model and call it from there.
|
||||
# as it is probably easier for model distribution.
|
||||
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
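The reworked `voice_conversion` supports both learned speaker embeddings and external d-vectors. A hedged call sketch against an already trained multi-speaker `Vits` instance; the `model`, the reference spectrogram, and the speaker ids are all assumed, and only the call signature shown above is relied on:

import torch

# assumed: `model` is a trained multi-speaker Vits with use_speaker_embedding=True
y = torch.randn(1, 220, 80)          # reference spectrogram [B, T, C], 80 mel channels assumed
y_lengths = torch.tensor([220])
src_speaker = torch.tensor([3])      # speaker id of the reference audio
tgt_speaker = torch.tensor([7])      # speaker id to convert into

with torch.no_grad():
    converted = model.voice_conversion(y, y_lengths, src_speaker, tgt_speaker)
# `converted` is assumed to carry the re-synthesised waveform for the target speaker
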
|
@@ -0,0 +1,122 @@
import json
import os
from typing import Dict, List

import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler


class LanguageManager:
    """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
    in a way that can be queried by language.

    Args:
        language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
        TTS models. Defaults to "".
        config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed.
        Defaults to None.

    Examples:
        >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
        >>> language_id_mapper = manager.language_ids
    """

    language_id_mapping: Dict = {}

    def __init__(
        self,
        language_ids_file_path: str = "",
        config: Coqpit = None,
    ):
        self.language_id_mapping = {}
        if language_ids_file_path:
            self.set_language_ids_from_file(language_ids_file_path)

        if config:
            self.set_language_ids_from_config(config)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    @property
    def num_languages(self) -> int:
        return len(list(self.language_id_mapping.keys()))

    @property
    def language_names(self) -> List:
        return list(self.language_id_mapping.keys())

    @staticmethod
    def parse_language_ids_from_config(c: Coqpit) -> Dict:
        """Set language id from config.

        Args:
            c (Coqpit): Config

        Returns:
            Tuple[Dict, int]: Language ID mapping and the number of languages.
        """
        languages = set({})
        for dataset in c.datasets:
            if "language" in dataset:
                languages.add(dataset["language"])
            else:
                raise ValueError(f"Dataset {dataset['name']} has no language specified.")
        return {name: i for i, name in enumerate(sorted(list(languages)))}

    def set_language_ids_from_config(self, c: Coqpit) -> None:
        """Set language IDs from config samples.

        Args:
            items (List): Data sampled returned by `load_meta_data()`.
        """
        self.language_id_mapping = self.parse_language_ids_from_config(c)

    def set_language_ids_from_file(self, file_path: str) -> None:
        """Load language ids from a json file.

        Args:
            file_path (str): Path to the target json file.
        """
        self.language_id_mapping = self._load_json(file_path)

    def save_language_ids_to_file(self, file_path: str) -> None:
        """Save language IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.language_id_mapping)


def _set_file_path(path):
    """Find the language_ids.json under the given path or the above it.
    Intended to band aid the different paths returned in restored and continued training."""
    path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
    path_continue = os.path.join(path, "language_ids.json")
    fs = fsspec.get_mapper(path).fs
    if fs.exists(path_restore):
        return path_restore
    if fs.exists(path_continue):
        return path_continue
    return None


def get_language_weighted_sampler(items: list):
    language_names = np.array([item[3] for item in items])
    unique_language_names = np.unique(language_names).tolist()
    language_ids = [unique_language_names.index(l) for l in language_names]
    language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
    weight_language = 1.0 / language_count
    dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))

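A short usage sketch for the new `LanguageManager` and `get_language_weighted_sampler`. The item layout `[text, wav_path, speaker_name, language]` is an assumption consistent with the `item[3]` lookup above, and the package is assumed to be importable:

from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler

manager = LanguageManager()
# mapping built by hand here; parse_language_ids_from_config(c) derives it from c.datasets
manager.language_id_mapping = {"en": 0, "pt-br": 1}
manager.save_language_ids_to_file("language_ids.json")

# assumed item layout: [text, wav_path, speaker_name, language]
items = [
    ["hello", "a.wav", "spk1", "en"],
    ["world", "b.wav", "spk1", "en"],
    ["olá", "c.wav", "spk2", "pt-br"],
]
sampler = get_language_weighted_sampler(items)  # under-represented languages get larger sampling weights
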
@ -7,9 +7,10 @@ import fsspec
|
|||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
|
@ -161,8 +162,10 @@ class SpeakerManager:
|
|||
file_path (str): Path to the target json file.
|
||||
"""
|
||||
self.d_vectors = self._load_json(file_path)
|
||||
|
||||
speakers = sorted({x["name"] for x in self.d_vectors.values()})
|
||||
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
|
||||
|
||||
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
|
||||
|
||||
def get_d_vector_by_clip(self, clip_idx: str) -> List:
|
||||
|
@ -209,6 +212,32 @@ class SpeakerManager:
|
|||
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
|
||||
return d_vectors
|
||||
|
||||
def get_random_speaker_id(self) -> Any:
|
||||
"""Get a random d_vector.
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
int: Speaker ID.
|
||||
"""
|
||||
if self.speaker_ids:
|
||||
return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
|
||||
|
||||
return None
|
||||
|
||||
def get_random_d_vector(self) -> Any:
|
||||
"""Get a random D ID.
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
np.ndarray: d_vector.
|
||||
"""
|
||||
if self.d_vectors:
|
||||
return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
|
||||
|
||||
return None
|
||||
|
||||
def get_speakers(self) -> List:
|
||||
return self.speaker_ids
|
||||
|
||||
|
@ -223,18 +252,15 @@ class SpeakerManager:
|
|||
config_path (str): Model config file path.
|
||||
"""
|
||||
self.speaker_encoder_config = load_config(config_path)
|
||||
self.speaker_encoder = setup_model(self.speaker_encoder_config)
|
||||
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
|
||||
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
|
||||
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
|
||||
# normalize the input audio level and trim silences
|
||||
# self.speaker_encoder_ap.do_sound_norm = True
|
||||
# self.speaker_encoder_ap.do_trim_silence = True
|
||||
|
||||
def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
|
||||
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
|
||||
"""Compute a d_vector from a given audio file.
|
||||
|
||||
Args:
|
||||
wav_file (Union[str, list]): Target file path.
|
||||
wav_file (Union[str, List[str]]): Target file path.
|
||||
|
||||
Returns:
|
||||
list: Computed d_vector.
|
||||
|
@ -242,12 +268,16 @@ class SpeakerManager:
|
|||
|
||||
def _compute(wav_file: str):
|
||||
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
|
||||
spec = self.speaker_encoder_ap.melspectrogram(waveform)
|
||||
spec = torch.from_numpy(spec.T)
|
||||
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
|
||||
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
|
||||
m_input = torch.from_numpy(m_input)
|
||||
else:
|
||||
m_input = torch.from_numpy(waveform)
|
||||
|
||||
if self.use_cuda:
|
||||
spec = spec.cuda()
|
||||
spec = spec.unsqueeze(0)
|
||||
d_vector = self.speaker_encoder.compute_embedding(spec)
|
||||
m_input = m_input.cuda()
|
||||
m_input = m_input.unsqueeze(0)
|
||||
d_vector = self.speaker_encoder.compute_embedding(m_input)
|
||||
return d_vector
|
||||
|
||||
if isinstance(wav_file, list):
|
||||
|
@ -364,11 +394,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
|||
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
|
||||
# new speaker manager with speaker IDs file.
|
||||
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
|
||||
print(
|
||||
" > Speaker manager is loaded with {} speakers: {}".format(
|
||||
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
|
||||
|
||||
if speaker_manager.num_speakers > 0:
|
||||
print(
|
||||
" > Speaker manager is loaded with {} speakers: {}".format(
|
||||
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# save file if path is defined
|
||||
if out_path:
|
||||
out_file_path = os.path.join(out_path, "speakers.json")
|
||||
|
@@ -378,3 +411,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     else:
         speaker_manager.save_speaker_ids_to_file(out_file_path)
     return speaker_manager
+
+
+def get_speaker_weighted_sampler(items: list):
+    speaker_names = np.array([item[2] for item in items])
+    unique_speaker_names = np.unique(speaker_names).tolist()
+    speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
+    speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
+    weight_speaker = 1.0 / speaker_count
+    dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
+    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))

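`SpeakerManager` gains `get_random_speaker_id` and `get_random_d_vector`, and loading a d-vector file now also yields a name-to-id mapping. A self-contained sketch of the underlying data layout and the two lookups, with made-up clips and a 512-dimensional embedding assumed:

import random

# assumed layout of a d-vector file: one entry per clip, keyed by clip name
d_vectors = {
    "clip_0001.wav": {"name": "spk1", "embedding": [0.1] * 512},
    "clip_0002.wav": {"name": "spk2", "embedding": [0.2] * 512},
}

# roughly what get_random_d_vector() does: pick a random clip and return its embedding
random_clip = random.choice(list(d_vectors.keys()))
random_d_vector = d_vectors[random_clip]["embedding"]

# what the new loader adds on top: derive speaker ids from the "name" fields
speakers = sorted({x["name"] for x in d_vectors.values()})
speaker_ids = {name: i for i, name in enumerate(speakers)}
print(speaker_ids)  # {'spk1': 0, 'spk2': 1}
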
@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def text_to_seq(text, CONFIG, custom_symbols=None):
|
||||
def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
|
||||
text_cleaner = [CONFIG.text_cleaner]
|
||||
# text or phonemes to sequence vector
|
||||
if CONFIG.use_phonemes:
|
||||
|
@ -23,7 +23,7 @@ def text_to_seq(text, CONFIG, custom_symbols=None):
|
|||
phoneme_to_sequence(
|
||||
text,
|
||||
text_cleaner,
|
||||
CONFIG.phoneme_language,
|
||||
language if language else CONFIG.phoneme_language,
|
||||
CONFIG.enable_eos_bos_chars,
|
||||
tp=CONFIG.characters,
|
||||
add_blank=CONFIG.add_blank,
|
||||
|
@ -71,6 +71,7 @@ def run_model_torch(
|
|||
speaker_id: int = None,
|
||||
style_mel: torch.Tensor = None,
|
||||
d_vector: torch.Tensor = None,
|
||||
language_id: torch.Tensor = None,
|
||||
) -> Dict:
|
||||
"""Run a torch model for inference. It does not support batch inference.
|
||||
|
||||
|
@ -96,6 +97,7 @@ def run_model_torch(
|
|||
"speaker_ids": speaker_id,
|
||||
"d_vectors": d_vector,
|
||||
"style_mel": style_mel,
|
||||
"language_ids": language_id,
|
||||
},
|
||||
)
|
||||
return outputs
|
||||
|
@@ -160,19 +162,20 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav


-def speaker_id_to_torch(speaker_id, cuda=False):
-    if speaker_id is not None:
-        speaker_id = np.asarray(speaker_id)
-        speaker_id = torch.from_numpy(speaker_id)
+def id_to_torch(aux_id, cuda=False):
+    if aux_id is not None:
+        aux_id = np.asarray(aux_id)
+        aux_id = torch.from_numpy(aux_id)
     if cuda:
-        return speaker_id.cuda()
-    return speaker_id
+        return aux_id.cuda()
+    return aux_id


 def embedding_to_torch(d_vector, cuda=False):
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
+        d_vector = d_vector.squeeze().unsqueeze(0)
     if cuda:
         return d_vector.cuda()
     return d_vector

@ -208,6 +211,8 @@ def synthesis(
|
|||
use_griffin_lim=False,
|
||||
do_trim_silence=False,
|
||||
d_vector=None,
|
||||
language_id=None,
|
||||
language_name=None,
|
||||
backend="torch",
|
||||
):
|
||||
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
|
||||
|
@ -244,6 +249,12 @@ def synthesis(
|
|||
d_vector (torch.Tensor):
|
||||
d-vector for multi-speaker models in shape :math:`[1, D]`. Defaults to None.
|
||||
|
||||
language_id (int):
|
||||
Language ID passed to the language embedding layer of a multilingual model. Defaults to None.
|
||||
|
||||
language_name (str):
|
||||
Language name corresponding to the language code used by the phonemizer. Defaults to None.
|
||||
|
||||
backend (str):
|
||||
tf or torch. Defaults to "torch".
|
||||
"""
|
||||
|
@ -258,15 +269,18 @@ def synthesis(
|
|||
if hasattr(model, "make_symbols"):
|
||||
custom_symbols = model.make_symbols(CONFIG)
|
||||
# preprocess the given text
|
||||
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
|
||||
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
|
||||
# pass tensors to backend
|
||||
if backend == "torch":
|
||||
if speaker_id is not None:
|
||||
speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
|
||||
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
|
||||
|
||||
if d_vector is not None:
|
||||
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
|
||||
|
||||
if language_id is not None:
|
||||
language_id = id_to_torch(language_id, cuda=use_cuda)
|
||||
|
||||
if not isinstance(style_mel, dict):
|
||||
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
|
||||
|
@ -278,7 +292,7 @@ def synthesis(
|
|||
text_inputs = tf.expand_dims(text_inputs, 0)
|
||||
# synthesize voice
|
||||
if backend == "torch":
|
||||
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
|
||||
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
|
||||
model_outputs = outputs["model_outputs"]
|
||||
model_outputs = model_outputs[0].data.cpu().numpy()
|
||||
alignments = outputs["alignments"]
|
||||
|
|
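`synthesis()` now threads `language_id` and `language_name` through `text_to_seq` and `run_model_torch`. A hedged call sketch mirroring the test-sentence code above; `model`, `CONFIG`, and `ap` are assumed to already exist:

from TTS.tts.utils.synthesis import synthesis

# assumed to exist: model (a multilingual Vits), CONFIG (its Coqpit config), ap (AudioProcessor)
wav, alignment, _, _ = synthesis(
    model,
    "Bom dia!",
    CONFIG,
    False,                       # use_cuda
    ap,
    speaker_id=0,                # index from SpeakerManager.speaker_ids
    d_vector=None,
    style_wav=None,
    language_id=1,               # index from LanguageManager.language_id_mapping
    language_name="pt-br",       # phonemizer language used by text_to_seq
    enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,
    use_griffin_lim=True,
    do_trim_silence=False,
).values()
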
|
@@ -135,3 +135,12 @@ def phoneme_cleaners(text):
     text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text
+
+
+def multilingual_cleaners(text):
+    """Pipeline for multilingual text"""
+    text = lowercase(text)
+    text = replace_symbols(text, lang=None)
+    text = remove_aux_symbols(text)
+    text = collapse_whitespace(text)
+    return text

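`multilingual_cleaners` is a language-agnostic cleaning pipeline: lower-casing, symbol replacement without a language-specific table, auxiliary-symbol removal, and whitespace collapsing. A small usage example, assuming the package is importable; the expected output is an assumption based on the steps above:

from TTS.tts.utils.text.cleaners import multilingual_cleaners

print(multilingual_cleaners("Olá,   Mundo!"))  # expected: "olá, mundo!"
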
@ -16,6 +16,60 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||
|
||||
TODO: Merge this with audio.py
|
||||
|
||||
Args:
|
||||
|
||||
n_fft (int):
|
||||
FFT window size for STFT.
|
||||
|
||||
hop_length (int):
|
||||
number of frames between STFT columns.
|
||||
|
||||
win_length (int, optional):
|
||||
STFT window length.
|
||||
|
||||
pad_wav (bool, optional):
|
||||
If True, pad the audio by (n_fft - hop_length) / 2. Defaults to False.
|
||||
|
||||
window (str, optional):
|
||||
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
|
||||
|
||||
sample_rate (int, optional):
|
||||
target audio sampling rate. Defaults to None.
|
||||
|
||||
mel_fmin (int, optional):
|
||||
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
mel_fmax (int, optional):
|
||||
maximum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
n_mels (int, optional):
|
||||
number of melspectrogram dimensions. Defaults to None.
|
||||
|
||||
use_mel (bool, optional):
|
||||
If True, compute melspectrograms; otherwise compute linear spectrograms. Defaults to False.
|
||||
|
||||
do_amp_to_db_linear (bool, optional):
|
||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
|
||||
|
||||
spec_gain (float, optional):
|
||||
gain applied when converting amplitude to DB. Defaults to 1.0.
|
||||
|
||||
power (float, optional):
|
||||
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
|
||||
|
||||
use_htk (bool, optional):
|
||||
Use HTK formula in mel filter instead of Slaney.
|
||||
|
||||
mel_norm (None, 'slaney', or number, optional):
|
||||
If 'slaney', divide the triangular mel weights by the width of the mel band
|
||||
(area normalization).
|
||||
|
||||
If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
|
||||
See `librosa.util.normalize` for a full description of supported norm values
|
||||
(including `+-np.inf`).
|
||||
|
||||
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -32,6 +86,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
use_mel=False,
|
||||
do_amp_to_db=False,
|
||||
spec_gain=1.0,
|
||||
power=None,
|
||||
use_htk=False,
|
||||
mel_norm="slaney",
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
|
@ -45,6 +102,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
self.use_mel = use_mel
|
||||
self.do_amp_to_db = do_amp_to_db
|
||||
self.spec_gain = spec_gain
|
||||
self.power = power
|
||||
self.use_htk = use_htk
|
||||
self.mel_norm = mel_norm
|
||||
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
||||
self.mel_basis = None
|
||||
if use_mel:
|
||||
|
@ -83,6 +143,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
M = o[:, :, :, 0]
|
||||
P = o[:, :, :, 1]
|
||||
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
||||
|
||||
if self.power is not None:
|
||||
S = S ** self.power
|
||||
|
||||
if self.use_mel:
|
||||
S = torch.matmul(self.mel_basis.to(x), S)
|
||||
if self.do_amp_to_db:
|
||||
|
@ -91,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
|
||||
def _build_mel_basis(self):
|
||||
mel_basis = librosa.filters.mel(
|
||||
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
|
||||
self.sample_rate,
|
||||
self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.mel_fmin,
|
||||
fmax=self.mel_fmax,
|
||||
htk=self.use_htk,
|
||||
norm=self.mel_norm,
|
||||
)
|
||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||
|
||||
|
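A small usage sketch of the layer above, assuming only the constructor keywords documented in this diff and the `TTS.utils.audio` import path; all values are illustrative.

import torch

from TTS.utils.audio import TorchSTFT  # class defined above; import path assumed

stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    sample_rate=22050,
    n_mels=80,
    use_mel=True,
    do_amp_to_db=True,
    mel_norm="slaney",
)
wav = torch.randn(1, 22050)  # [batch, samples]; one second of noise at 22.05 kHz
mel = stft(wav)              # [batch, n_mels, frames] log-mel spectrogram
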
@ -167,7 +237,7 @@ class AudioProcessor(object):
|
|||
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
mel_fmax (int, optional):
|
||||
maximum filter frequency for computing melspectrograms.. Defaults to None.
|
||||
maximum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
spec_gain (int, optional):
|
||||
gain applied when converting amplitude to DB. Defaults to 20.
|
||||
|
@ -196,6 +266,12 @@ class AudioProcessor(object):
|
|||
do_amp_to_db_mel (bool, optional):
|
||||
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||
|
||||
do_rms_norm (bool, optional):
|
||||
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
|
||||
|
||||
db_level (int, optional):
|
||||
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
|
||||
|
||||
stats_path (str, optional):
|
||||
Path to the computed stats file. Defaults to None.
|
||||
|
||||
|
@ -233,6 +309,8 @@ class AudioProcessor(object):
|
|||
do_sound_norm=False,
|
||||
do_amp_to_db_linear=True,
|
||||
do_amp_to_db_mel=True,
|
||||
do_rms_norm=False,
|
||||
db_level=None,
|
||||
stats_path=None,
|
||||
verbose=True,
|
||||
**_,
|
||||
|
@ -264,6 +342,8 @@ class AudioProcessor(object):
|
|||
self.do_sound_norm = do_sound_norm
|
||||
self.do_amp_to_db_linear = do_amp_to_db_linear
|
||||
self.do_amp_to_db_mel = do_amp_to_db_mel
|
||||
self.do_rms_norm = do_rms_norm
|
||||
self.db_level = db_level
|
||||
self.stats_path = stats_path
|
||||
# setup exp_func for db to amp conversion
|
||||
if log_func == "np.log":
|
||||
|
@ -656,21 +736,6 @@ class AudioProcessor(object):
|
|||
frame_period=1000 * self.hop_length / self.sample_rate,
|
||||
)
|
||||
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
|
||||
# pad = int((self.win_length / self.hop_length) / 2)
|
||||
# f0 = [0.0] * pad + f0 + [0.0] * pad
|
||||
# f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
|
||||
# f0 = np.array(f0, dtype=np.float32)
|
||||
|
||||
# f01, _, _ = librosa.pyin(
|
||||
# x,
|
||||
# fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
|
||||
# fmax=self.mel_fmax,
|
||||
# frame_length=self.win_length,
|
||||
# sr=self.sample_rate,
|
||||
# fill_na=0.0,
|
||||
# )
|
||||
|
||||
# spec = self.melspectrogram(x)
|
||||
return f0
|
||||
|
||||
### Audio Processing ###
|
||||
|
@ -713,10 +778,33 @@ class AudioProcessor(object):
|
|||
"""
|
||||
return x / abs(x).max() * 0.95
|
||||
|
||||
@staticmethod
def _rms_norm(wav, db_level=-27):
    r = 10 ** (db_level / 20)
    a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
    return wav * a

def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
    """Normalize the volume based on RMS of the signal.

    Args:
        x (np.ndarray): Raw waveform.
        db_level (float, optional): Target RMS level in dB, in the range -99 to 0. Defaults to None (uses `self.db_level`).

    Returns:
        np.ndarray: RMS normalized waveform.
    """
    if db_level is None:
        db_level = self.db_level
    assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
    wav = self._rms_norm(x, db_level)
    return wav

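The gain above solves 20 * log10(rms(a * wav)) = db_level. A tiny self-contained check of that identity (standalone NumPy, not the class method):

import numpy as np


def rms_norm(wav, db_level=-27):
    # scale so that the RMS of the output sits at `db_level` dBFS
    r = 10 ** (db_level / 20)
    a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
    return wav * a


wav = np.random.uniform(-1, 1, 16000).astype(np.float32)
out = rms_norm(wav, db_level=-27)
print(20 * np.log10(np.sqrt(np.mean(out ** 2))))  # ~ -27.0
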
### save and load ###
|
||||
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
|
||||
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
|
||||
|
||||
Resampling slows down loading the file significantly, so it is recommended to resample the files beforehand.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the wav file.
|
||||
sr (int, optional): Sampling rate for resampling. Defaults to None.
|
||||
|
@ -725,8 +813,10 @@ class AudioProcessor(object):
|
|||
np.ndarray: Loaded waveform.
|
||||
"""
|
||||
if self.resample:
|
||||
# loading with resampling. It is significantly slower.
|
||||
x, sr = librosa.load(filename, sr=self.sample_rate)
|
||||
elif sr is None:
|
||||
# SF is faster than librosa for loading files
|
||||
x, sr = sf.read(filename)
|
||||
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
|
||||
else:
|
||||
|
@ -738,6 +828,8 @@ class AudioProcessor(object):
|
|||
print(f" [!] File cannot be trimmed for silence - {filename}")
|
||||
if self.do_sound_norm:
|
||||
x = self.sound_norm(x)
|
||||
if self.do_rms_norm:
|
||||
x = self.rms_volume_norm(x, self.db_level)
|
||||
return x
|
||||
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
|
||||
|
|
|
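A short sketch of enabling the new RMS normalization when loading audio; the file path and values are placeholders.

from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(
    sample_rate=22050,
    do_rms_norm=True,   # new flag added in this diff
    db_level=-27,       # target RMS level in dB, must be in [-99, 0]
    resample=False,
)
wav = ap.load_wav("sample.wav")  # placeholder path; waveform is RMS-normalized on load
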
@ -26,7 +26,7 @@ class AttrDict(dict):
|
|||
self.__dict__ = self
|
||||
|
||||
|
||||
def copy_model_files(config: Coqpit, out_path, new_fields):
|
||||
def copy_model_files(config: Coqpit, out_path, new_fields=None):
|
||||
"""Copy config.json and other model files to training folder and add
|
||||
new fields.
|
||||
|
||||
|
|
|
@ -46,36 +46,66 @@ class ModelManager(object):
|
|||
with open(file_path, "r", encoding="utf-8") as json_file:
|
||||
self.models_dict = json.load(json_file)
|
||||
|
||||
def list_langs(self):
|
||||
print(" Name format: type/language")
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
print(f" >: {model_type}/{lang} ")
|
||||
def _list_models(self, model_type, model_count=0):
|
||||
model_list = []
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
for model in self.models_dict[model_type][lang][dataset]:
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
|
||||
else:
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
|
||||
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
|
||||
model_count += 1
|
||||
return model_list
|
||||
|
||||
def list_datasets(self):
|
||||
print(" Name format: type/language/dataset")
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
print(f" >: {model_type}/{lang}/{dataset}")
|
||||
def _list_for_model_type(self, model_type):
|
||||
print(" Name format: language/dataset/model")
|
||||
models_name_list = []
|
||||
model_count = 1
|
||||
model_type = "tts_models"
|
||||
models_name_list.extend(self._list_models(model_type, model_count))
|
||||
return [name.replace(model_type + "/", "") for name in models_name_list]
|
||||
|
||||
def list_models(self):
|
||||
print(" Name format: type/language/dataset/model")
|
||||
models_name_list = []
|
||||
model_count = 1
|
||||
for model_type in self.models_dict:
|
||||
model_list = self._list_models(model_type, model_count)
|
||||
models_name_list.extend(model_list)
|
||||
return models_name_list
|
||||
|
||||
def list_tts_models(self):
|
||||
"""Print all `TTS` models and return a list of model names
|
||||
|
||||
Format is `language/dataset/model`
|
||||
"""
|
||||
return self._list_for_model_type("tts_models")
|
||||
|
||||
def list_vocoder_models(self):
|
||||
"""Print all the `vocoder` models and return a list of model names
|
||||
|
||||
Format is `language/dataset/model`
|
||||
"""
|
||||
return self._list_for_model_type("vocoder_models")
|
||||
|
||||
def list_langs(self):
|
||||
"""Print all the available languages"""
|
||||
print(" Name format: type/language")
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
print(f" >: {model_type}/{lang} ")
|
||||
|
||||
def list_datasets(self):
|
||||
"""Print all the datasets"""
|
||||
print(" Name format: type/language/dataset")
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
for model in self.models_dict[model_type][lang][dataset]:
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
|
||||
else:
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
|
||||
models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
|
||||
model_count += 1
|
||||
return models_name_list
|
||||
print(f" >: {model_type}/{lang}/{dataset}")
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
|
@ -121,6 +151,8 @@ class ModelManager(object):
|
|||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
||||
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
|
||||
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
|
||||
|
||||
# update the scale_path.npy file path in the model config.json
|
||||
self._update_path("audio.stats_path", output_stats_path, config_path)
|
||||
|
@ -133,6 +165,12 @@ class ModelManager(object):
|
|||
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
||||
|
||||
# update the speaker_encoder file path in the model config.json to the current path
|
||||
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
|
||||
@staticmethod
|
||||
def _update_path(field_name, new_path, config_path):
|
||||
"""Update the path in the model config.json for the current environment after download"""
|
||||
|
@ -159,8 +197,12 @@ class ModelManager(object):
|
|||
# download the file
|
||||
r = requests.get(file_url)
|
||||
# extract the file
|
||||
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
|
||||
z.extractall(output_folder)
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
|
||||
z.extractall(output_folder)
|
||||
except zipfile.BadZipFile:
|
||||
print(f" > Error: Bad zip file - {file_url}")
|
||||
raise zipfile.BadZipFile
|
||||
# move the files to the outer path
|
||||
for file_path in z.namelist()[1:]:
|
||||
src_path = os.path.join(output_folder, file_path)
|
||||
|
|
|
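For reference, a hedged sketch of the list/download flow above; the path to the models catalogue is an assumption (it ships as `.models.json` inside the TTS package).

from TTS.utils.manage import ModelManager

manager = ModelManager("TTS/.models.json")  # assumed location of the models catalogue
all_models = manager.list_models()          # "type/language/dataset/model" strings
tts_models = manager.list_tts_models()      # "language/dataset/model" strings

# download_model fetches and unzips the release and fixes up paths in its config.json
model_path, config_path, model_item = manager.download_model(
    "tts_models/multilingual/multi-dataset/your_tts"
)
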
@ -1,12 +1,13 @@
|
|||
import time
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pysbd
|
||||
import torch
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
|
@ -23,6 +24,7 @@ class Synthesizer(object):
|
|||
tts_checkpoint: str,
|
||||
tts_config_path: str,
|
||||
tts_speakers_file: str = "",
|
||||
tts_languages_file: str = "",
|
||||
vocoder_checkpoint: str = "",
|
||||
vocoder_config: str = "",
|
||||
encoder_checkpoint: str = "",
|
||||
|
@ -52,6 +54,7 @@ class Synthesizer(object):
|
|||
self.tts_checkpoint = tts_checkpoint
|
||||
self.tts_config_path = tts_config_path
|
||||
self.tts_speakers_file = tts_speakers_file
|
||||
self.tts_languages_file = tts_languages_file
|
||||
self.vocoder_checkpoint = vocoder_checkpoint
|
||||
self.vocoder_config = vocoder_config
|
||||
self.encoder_checkpoint = encoder_checkpoint
|
||||
|
@ -63,6 +66,9 @@ class Synthesizer(object):
|
|||
self.speaker_manager = None
|
||||
self.num_speakers = 0
|
||||
self.tts_speakers = {}
|
||||
self.language_manager = None
|
||||
self.num_languages = 0
|
||||
self.tts_languages = {}
|
||||
self.d_vector_dim = 0
|
||||
self.seg = self._get_segmenter("en")
|
||||
self.use_cuda = use_cuda
|
||||
|
@ -110,29 +116,93 @@ class Synthesizer(object):
|
|||
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
|
||||
|
||||
speaker_manager = self._init_speaker_manager()
|
||||
language_manager = self._init_language_manager()
|
||||
self._set_speaker_encoder_paths_from_tts_config()
|
||||
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
||||
|
||||
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
|
||||
if language_manager is not None:
|
||||
self.tts_model = setup_tts_model(
|
||||
config=self.tts_config,
|
||||
speaker_manager=speaker_manager,
|
||||
language_manager=language_manager,
|
||||
)
|
||||
else:
|
||||
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
|
||||
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
||||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
||||
def _set_speaker_encoder_paths_from_tts_config(self):
|
||||
"""Set the encoder paths from the tts model config for models with speaker encoders."""
|
||||
if hasattr(self.tts_config, "model_args") and hasattr(
|
||||
self.tts_config.model_args, "speaker_encoder_config_path"
|
||||
):
|
||||
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
|
||||
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
|
||||
|
||||
def _is_use_speaker_embedding(self):
|
||||
"""Check if the speaker embedding is used in the model"""
|
||||
# handle the case where some models use `model_args` and others don't
|
||||
use_speaker_embedding = False
|
||||
if hasattr(self.tts_config, "model_args"):
|
||||
use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False)
|
||||
use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False)
|
||||
return use_speaker_embedding
|
||||
|
||||
def _is_use_d_vector_file(self):
|
||||
"""Check if the d-vector file is used in the model"""
|
||||
# handle the case where some models use `model_args` and others don't
|
||||
use_d_vector_file = False
|
||||
if hasattr(self.tts_config, "model_args"):
|
||||
config = self.tts_config.model_args
|
||||
use_d_vector_file = config.get("use_d_vector_file", False)
|
||||
config = self.tts_config
|
||||
use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False)
|
||||
return use_d_vector_file
|
||||
|
||||
def _init_speaker_manager(self):
|
||||
"""Initialize the SpeakerManager"""
|
||||
# setup if multi-speaker settings are in the global model config
|
||||
speaker_manager = None
|
||||
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
|
||||
speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
|
||||
if self._is_use_speaker_embedding():
|
||||
if self.tts_speakers_file:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
||||
if self.tts_config.get("speakers_file", None):
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
|
||||
if speakers_file:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
|
||||
|
||||
if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
|
||||
if self._is_use_d_vector_file():
|
||||
d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
|
||||
if self.tts_speakers_file:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
||||
if self.tts_config.get("d_vector_file", None):
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
||||
if d_vector_file:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
|
||||
return speaker_manager
|
||||
|
||||
def _init_speaker_encoder(self, speaker_manager):
|
||||
"""Initialize the SpeakerEncoder"""
|
||||
if self.encoder_checkpoint:
|
||||
if speaker_manager is None:
|
||||
speaker_manager = SpeakerManager(
|
||||
encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config
|
||||
)
|
||||
else:
|
||||
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
|
||||
return speaker_manager
|
||||
|
||||
def _init_language_manager(self):
|
||||
"""Initialize the LanguageManager"""
|
||||
# setup if multi-lingual settings are in the global model config
|
||||
language_manager = None
|
||||
if check_config_and_model_args(self.tts_config, "use_language_embedding", True):
|
||||
if self.tts_languages_file:
|
||||
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
|
||||
elif self.tts_config.get("language_ids_file", None):
|
||||
language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
|
||||
else:
|
||||
language_manager = LanguageManager(config=self.tts_config)
|
||||
return language_manager
|
||||
|
||||
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
|
||||
"""Load the vocoder model.
|
||||
|
||||
|
@ -174,13 +244,21 @@ class Synthesizer(object):
|
|||
wav = np.array(wav)
|
||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||
|
||||
def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
|
||||
def tts(
|
||||
self,
|
||||
text: str,
|
||||
speaker_name: str = "",
|
||||
language_name: str = "",
|
||||
speaker_wav: Union[str, List[str]] = None,
|
||||
style_wav=None,
|
||||
) -> List[int]:
|
||||
"""🐸 TTS magic. Run all the models and generate speech.
|
||||
|
||||
Args:
|
||||
text (str): input text.
|
||||
speaker_idx (str, optional): speaker id for multi-speaker models. Defaults to "".
|
||||
speaker_wav ():
|
||||
speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
|
||||
language_name (str, optional): language id for multi-language models. Defaults to "".
|
||||
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
|
||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||
|
||||
Returns:
|
||||
|
@ -196,29 +274,49 @@ class Synthesizer(object):
|
|||
speaker_embedding = None
|
||||
speaker_id = None
|
||||
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
|
||||
if speaker_idx and isinstance(speaker_idx, str):
|
||||
if speaker_name and isinstance(speaker_name, str):
|
||||
if self.tts_config.use_d_vector_file:
|
||||
# get the speaker embedding from the saved d_vectors.
|
||||
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
|
||||
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_name)[0]
|
||||
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
|
||||
else:
|
||||
# get speaker idx from the speaker name
|
||||
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
|
||||
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
|
||||
|
||||
elif not speaker_idx and not speaker_wav:
|
||||
elif not speaker_name and not speaker_wav:
|
||||
raise ValueError(
|
||||
" [!] Look like you use a multi-speaker model. "
|
||||
"You need to define either a `speaker_idx` or a `style_wav` to use a multi-speaker model."
|
||||
"You need to define either a `speaker_name` or a `style_wav` to use a multi-speaker model."
|
||||
)
|
||||
else:
|
||||
speaker_embedding = None
|
||||
else:
|
||||
if speaker_idx:
|
||||
if speaker_name:
|
||||
raise ValueError(
|
||||
f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}."
|
||||
f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
|
||||
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
|
||||
)
|
||||
|
||||
# handle multi-lingual
|
||||
language_id = None
|
||||
if self.tts_languages_file or (
|
||||
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
|
||||
):
|
||||
if language_name and isinstance(language_name, str):
|
||||
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
|
||||
|
||||
elif not language_name:
|
||||
raise ValueError(
|
||||
" [!] Look like you use a multi-lingual model. "
|
||||
"You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f" [!] Missing language_ids.json file path for selecting language {language_name}."
|
||||
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
|
||||
)
|
||||
|
||||
# compute a new d_vector from the given clip.
|
||||
if speaker_wav is not None:
|
||||
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
|
||||
|
@ -234,6 +332,8 @@ class Synthesizer(object):
|
|||
use_cuda=self.use_cuda,
|
||||
ap=self.ap,
|
||||
speaker_id=speaker_id,
|
||||
language_id=language_id,
|
||||
language_name=language_name,
|
||||
style_wav=style_wav,
|
||||
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
||||
use_griffin_lim=use_gl,
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
# This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
|
||||
import collections
|
||||
import contextlib
|
||||
import wave
|
||||
|
||||
import webrtcvad
|
||||
|
||||
|
||||
def read_wave(path):
|
||||
"""Reads a .wav file.
|
||||
|
||||
Takes the path, and returns (PCM audio data, sample rate).
|
||||
"""
|
||||
with contextlib.closing(wave.open(path, "rb")) as wf:
|
||||
num_channels = wf.getnchannels()
|
||||
assert num_channels == 1
|
||||
sample_width = wf.getsampwidth()
|
||||
assert sample_width == 2
|
||||
sample_rate = wf.getframerate()
|
||||
assert sample_rate in (8000, 16000, 32000, 48000)
|
||||
pcm_data = wf.readframes(wf.getnframes())
|
||||
return pcm_data, sample_rate
|
||||
|
||||
|
||||
def write_wave(path, audio, sample_rate):
|
||||
"""Writes a .wav file.
|
||||
|
||||
Takes path, PCM audio data, and sample rate.
|
||||
"""
|
||||
with contextlib.closing(wave.open(path, "wb")) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
|
||||
|
||||
class Frame(object):
|
||||
"""Represents a "frame" of audio data."""
|
||||
|
||||
def __init__(self, _bytes, timestamp, duration):
|
||||
self.bytes = _bytes
|
||||
self.timestamp = timestamp
|
||||
self.duration = duration
|
||||
|
||||
|
||||
def frame_generator(frame_duration_ms, audio, sample_rate):
|
||||
"""Generates audio frames from PCM audio data.
|
||||
|
||||
Takes the desired frame duration in milliseconds, the PCM data, and
|
||||
the sample rate.
|
||||
|
||||
Yields Frames of the requested duration.
|
||||
"""
|
||||
n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
|
||||
offset = 0
|
||||
timestamp = 0.0
|
||||
duration = (float(n) / sample_rate) / 2.0
|
||||
while offset + n < len(audio):
|
||||
yield Frame(audio[offset : offset + n], timestamp, duration)
|
||||
timestamp += duration
|
||||
offset += n
|
||||
|
||||
|
||||
def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
|
||||
"""Filters out non-voiced audio frames.
|
||||
|
||||
Given a webrtcvad.Vad and a source of audio frames, yields only
|
||||
the voiced audio.
|
||||
|
||||
Uses a padded, sliding window algorithm over the audio frames.
|
||||
When more than 90% of the frames in the window are voiced (as
|
||||
reported by the VAD), the collector triggers and begins yielding
|
||||
audio frames. Then the collector waits until 90% of the frames in
|
||||
the window are unvoiced to detrigger.
|
||||
|
||||
The window is padded at the front and back to provide a small
|
||||
amount of silence or the beginnings/endings of speech around the
|
||||
voiced frames.
|
||||
|
||||
Arguments:
|
||||
|
||||
sample_rate - The audio sample rate, in Hz.
|
||||
frame_duration_ms - The frame duration in milliseconds.
|
||||
padding_duration_ms - The amount to pad the window, in milliseconds.
|
||||
vad - An instance of webrtcvad.Vad.
|
||||
frames - a source of audio frames (sequence or generator).
|
||||
|
||||
Returns: A generator that yields PCM audio data.
|
||||
"""
|
||||
num_padding_frames = int(padding_duration_ms / frame_duration_ms)
|
||||
# We use a deque for our sliding window/ring buffer.
|
||||
ring_buffer = collections.deque(maxlen=num_padding_frames)
|
||||
# We have two states: TRIGGERED and NOTTRIGGERED. We start in the
|
||||
# NOTTRIGGERED state.
|
||||
triggered = False
|
||||
|
||||
voiced_frames = []
|
||||
for frame in frames:
|
||||
is_speech = vad.is_speech(frame.bytes, sample_rate)
|
||||
|
||||
# sys.stdout.write('1' if is_speech else '0')
|
||||
if not triggered:
|
||||
ring_buffer.append((frame, is_speech))
|
||||
num_voiced = len([f for f, speech in ring_buffer if speech])
|
||||
# If we're NOTTRIGGERED and more than 90% of the frames in
|
||||
# the ring buffer are voiced frames, then enter the
|
||||
# TRIGGERED state.
|
||||
if num_voiced > 0.9 * ring_buffer.maxlen:
|
||||
triggered = True
|
||||
# sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,))
|
||||
# We want to yield all the audio we see from now until
|
||||
# we are NOTTRIGGERED, but we have to start with the
|
||||
# audio that's already in the ring buffer.
|
||||
for f, _ in ring_buffer:
|
||||
voiced_frames.append(f)
|
||||
ring_buffer.clear()
|
||||
else:
|
||||
# We're in the TRIGGERED state, so collect the audio data
|
||||
# and add it to the ring buffer.
|
||||
voiced_frames.append(frame)
|
||||
ring_buffer.append((frame, is_speech))
|
||||
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
|
||||
# If more than 90% of the frames in the ring buffer are
|
||||
# unvoiced, then enter NOTTRIGGERED and yield whatever
|
||||
# audio we've collected.
|
||||
if num_unvoiced > 0.9 * ring_buffer.maxlen:
|
||||
# sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
|
||||
triggered = False
|
||||
yield b"".join([f.bytes for f in voiced_frames])
|
||||
ring_buffer.clear()
|
||||
voiced_frames = []
|
||||
# If we have any leftover voiced audio when we run out of input,
|
||||
# yield it.
|
||||
if voiced_frames:
|
||||
yield b"".join([f.bytes for f in voiced_frames])
|
||||
|
||||
|
||||
def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
|
||||
|
||||
vad = webrtcvad.Vad(int(aggressiveness))
|
||||
frames = list(frame_generator(30, audio, sample_rate))
|
||||
segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames)
|
||||
|
||||
return segments
|
|
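A usage sketch for the helpers in this file, assuming a mono 16-bit PCM wav at one of the supported sample rates; file names are placeholders.

# read_wave / get_vad_speech_segments / write_wave are defined above in this module
audio, sample_rate = read_wave("input_16k.wav")          # placeholder path
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2)
for i, segment in enumerate(segments):
    write_wave(f"chunk-{i:02d}.wav", segment, sample_rate)  # voiced chunks only
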
@@ -3,10 +3,15 @@

VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
) is an End-to-End (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAE,
Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
using MAS as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
using MAS, as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
It is a feed-forward model with x67.12 real-time factor on a GPU.

🐸 YourTTS is a multi-speaker and multi-lingual TTS model that can perform voice conversion and zero-shot speaker adaptation.
It can also learn a new language or voice from roughly one minute of audio. This opens the door to training
TTS models for low-resource languages. 🐸 YourTTS uses VITS as the backbone architecture coupled with a speaker encoder model.

## Important resources & papers
- 🐸 YourTTS: https://arxiv.org/abs/2112.02418
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf

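Below is a hedged sketch of running YourTTS for zero-shot synthesis through the `Synthesizer` class shown earlier in this diff; all file paths are placeholders for the files downloaded with the `tts_models/multilingual/multi-dataset/your_tts` release, and the reference clip is any short recording of the target voice.

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="model_file.pth.tar",      # placeholder paths from the downloaded release
    tts_config_path="config.json",
    tts_speakers_file="speakers.json",
    tts_languages_file="language_ids.json",
    use_cuda=False,
)
wav = synthesizer.tts(
    "This is zero-shot voice cloning.",
    speaker_wav="reference_speaker.wav",      # clip used to compute the speaker d-vector
    language_name="en",
)
synthesizer.save_wav(wav, "output.wav")
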
@ -180,7 +180,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
|
|||
|
||||
plt.figure()
|
||||
plt.rcParams["figure.figsize"] = (50, 20)
|
||||
barplot = sns.barplot(x, y)
|
||||
barplot = sns.barplot(x=x, y=y)
|
||||
if save_path:
|
||||
fig = barplot.get_figure()
|
||||
fig.savefig(os.path.join(save_path, "phoneme_dist"))
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
import os
|
||||
from glob import glob
|
||||
|
||||
from TTS.config.shared_configs import BaseAudioConfig
|
||||
from TTS.trainer import Trainer, TrainingArgs
|
||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models.vits import Vits, VitsArgs
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
mailabs_path = "/home/julian/workspace/mailabs/**"
|
||||
dataset_paths = glob(mailabs_path)
|
||||
dataset_config = [
|
||||
BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
|
||||
for path in dataset_paths
|
||||
]
|
||||
|
||||
audio_config = BaseAudioConfig(
|
||||
sample_rate=16000,
|
||||
win_length=1024,
|
||||
hop_length=256,
|
||||
num_mels=80,
|
||||
preemphasis=0.0,
|
||||
ref_level_db=20,
|
||||
log_func="np.log",
|
||||
do_trim_silence=False,
|
||||
trim_db=23.0,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
spec_gain=1.0,
|
||||
signal_norm=True,
|
||||
do_amp_to_db_linear=False,
|
||||
resample=False,
|
||||
)
|
||||
|
||||
vitsArgs = VitsArgs(
|
||||
use_language_embedding=True,
|
||||
embedded_language_dim=4,
|
||||
use_speaker_embedding=True,
|
||||
use_sdp=False,
|
||||
)
|
||||
|
||||
config = VitsConfig(
|
||||
model_args=vitsArgs,
|
||||
audio=audio_config,
|
||||
run_name="vits_vctk",
|
||||
use_speaker_embedding=True,
|
||||
batch_size=32,
|
||||
eval_batch_size=16,
|
||||
batch_group_size=0,
|
||||
num_loader_workers=4,
|
||||
num_eval_loader_workers=4,
|
||||
run_eval=True,
|
||||
test_delay_epochs=-1,
|
||||
epochs=1000,
|
||||
text_cleaner="multilingual_cleaners",
|
||||
use_phonemes=False,
|
||||
phoneme_language="en-us",
|
||||
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
|
||||
compute_input_seq_cache=True,
|
||||
print_step=25,
|
||||
use_language_weighted_sampler=True,
|
||||
print_eval=False,
|
||||
mixed_precision=False,
|
||||
sort_by_audio_len=True,
|
||||
min_seq_len=32 * 256 * 4,
|
||||
max_seq_len=160000,
|
||||
output_path=output_path,
|
||||
datasets=dataset_config,
|
||||
characters={
|
||||
"pad": "_",
|
||||
"eos": "&",
|
||||
"bos": "*",
|
||||
"characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
|
||||
"punctuations": "!¡'(),-.:;¿? ",
|
||||
"phonemes": None,
|
||||
"unique": True,
|
||||
},
|
||||
test_sentences=[
|
||||
[
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"mary_ann",
|
||||
None,
|
||||
"en_US",
|
||||
],
|
||||
[
|
||||
"Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
|
||||
"ezwa",
|
||||
None,
|
||||
"fr_FR",
|
||||
],
|
||||
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
|
||||
["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
|
||||
],
|
||||
)
|
||||
|
||||
# init audio processor
|
||||
ap = AudioProcessor(**config.audio.to_dict())
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
||||
|
||||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
language_manager = LanguageManager(config=config)
|
||||
config.model_args.num_languages = language_manager.num_languages
|
||||
|
||||
# init model
|
||||
model = Vits(config, speaker_manager, language_manager)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
TrainingArgs(),
|
||||
config,
|
||||
output_path,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
training_assets={"audio_processor": ap},
|
||||
)
|
||||
trainer.fit()
|
|
@ -26,3 +26,5 @@ unidic-lite==1.0.8
|
|||
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
|
||||
fsspec>=2021.04.0
|
||||
pyworld
|
||||
webrtcvad
|
||||
torchaudio
|
||||
|
|
|
@ -38,3 +38,14 @@ def run_cli(command):
|
|||
|
||||
def get_test_data_config():
|
||||
return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
|
||||
|
||||
|
||||
def assertHasAttr(test_obj, obj, intendedAttr):
|
||||
# from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead
|
||||
testBool = hasattr(obj, intendedAttr)
|
||||
test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
|
||||
|
||||
|
||||
def assertHasNotAttr(test_obj, obj, intendedAttr):
|
||||
testBool = hasattr(obj, intendedAttr)
|
||||
test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from tests import get_tests_output_path, run_cli
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
|
||||
torch.manual_seed(1)
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
|
||||
dataset_config_en = BaseDatasetConfig(
|
||||
name="ljspeech",
|
||||
meta_file_train="metadata.csv",
|
||||
meta_file_val="metadata.csv",
|
||||
path="tests/data/ljspeech",
|
||||
language="en",
|
||||
)
|
||||
|
||||
dataset_config_pt = BaseDatasetConfig(
|
||||
name="ljspeech",
|
||||
meta_file_train="metadata.csv",
|
||||
meta_file_val="metadata.csv",
|
||||
path="tests/data/ljspeech",
|
||||
language="pt-br",
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
class TestFindUniquePhonemes(unittest.TestCase):
|
||||
@staticmethod
|
||||
def test_espeak_phonemes():
|
||||
# prepare the config
|
||||
config = VitsConfig(
|
||||
batch_size=2,
|
||||
eval_batch_size=2,
|
||||
num_loader_workers=0,
|
||||
num_eval_loader_workers=0,
|
||||
text_cleaner="english_cleaners",
|
||||
use_phonemes=True,
|
||||
use_espeak_phonemes=True,
|
||||
phoneme_language="en-us",
|
||||
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
|
||||
run_eval=True,
|
||||
test_delay_epochs=-1,
|
||||
epochs=1,
|
||||
print_step=1,
|
||||
print_eval=True,
|
||||
datasets=[dataset_config_en, dataset_config_pt],
|
||||
)
|
||||
config.save_json(config_path)
|
||||
|
||||
# run test
|
||||
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
|
||||
|
||||
@staticmethod
|
||||
def test_no_espeak_phonemes():
|
||||
# prepare the config
|
||||
config = VitsConfig(
|
||||
batch_size=2,
|
||||
eval_batch_size=2,
|
||||
num_loader_workers=0,
|
||||
num_eval_loader_workers=0,
|
||||
text_cleaner="english_cleaners",
|
||||
use_phonemes=True,
|
||||
use_espeak_phonemes=False,
|
||||
phoneme_language="en-us",
|
||||
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
|
||||
run_eval=True,
|
||||
test_delay_epochs=-1,
|
||||
epochs=1,
|
||||
print_step=1,
|
||||
print_eval=True,
|
||||
datasets=[dataset_config_en, dataset_config_pt],
|
||||
)
|
||||
config.save_json(config_path)
|
||||
|
||||
# run test
|
||||
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
|
|
@ -0,0 +1,29 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from tests import get_tests_input_path, get_tests_output_path, run_cli
|
||||
|
||||
torch.manual_seed(1)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
class TestRemoveSilenceVAD(unittest.TestCase):
|
||||
@staticmethod
|
||||
def test():
|
||||
# set paths
|
||||
wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
|
||||
output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
|
||||
output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")
|
||||
|
||||
# resample audios
|
||||
run_cli(
|
||||
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
|
||||
)
|
||||
|
||||
# run test
|
||||
run_cli(
|
||||
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
|
||||
)
|
||||
run_cli(f'rm -rf "{output_resample_path}"')
|
||||
run_cli(f'rm -rf "{output_path}"')
|
|
@ -13,7 +13,7 @@ file_path = get_tests_input_path()
|
|||
class LSTMSpeakerEncoderTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
dummy_input = T.rand(4, 20, 80) # B x T x D
|
||||
dummy_input = T.rand(4, 80, 20) # B x D x T
|
||||
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
||||
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
|
||||
# computing d vectors
|
||||
|
@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
|
|||
assert output.type() == "torch.FloatTensor"
|
||||
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
||||
# compute d for a given batch
|
||||
dummy_input = T.rand(1, 240, 80) # B x T x D
|
||||
dummy_input = T.rand(1, 80, 240)  # B x D x T
|
||||
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
|
||||
assert output.shape[0] == 1
|
||||
assert output.shape[1] == 256
|
||||
|
@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
|
|||
class ResNetSpeakerEncoderTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
dummy_input = T.rand(4, 20, 80) # B x T x D
|
||||
dummy_input = T.rand(4, 80, 20) # B x D x T
|
||||
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
||||
model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
|
||||
# computing d vectors
|
||||
|
@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
|
|||
assert output.type() == "torch.FloatTensor"
|
||||
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
||||
# compute d for a given batch
|
||||
dummy_input = T.rand(1, 240, 80) # B x T x D
|
||||
dummy_input = T.rand(1, 80, 240) # B x D x T
|
||||
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
|
||||
assert output.shape[0] == 1
|
||||
assert output.shape[1] == 256
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.speaker_encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
|||
config.audio.resample = True
|
||||
|
||||
# create a dummy speaker encoder
|
||||
model = setup_model(config)
|
||||
model = setup_speaker_encoder_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
|
@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
|||
# load a sample audio and compute embedding
|
||||
waveform = ap.load_wav(sample_wav_path)
|
||||
mel = ap.melspectrogram(waveform)
|
||||
d_vector = manager.compute_d_vector(mel.T)
|
||||
d_vector = manager.compute_d_vector(mel)
|
||||
assert d_vector.shape[1] == 256
|
||||
|
||||
# compute d_vector directly from an input file
|
||||
|
|
|
@ -38,6 +38,11 @@ class TestTTSDataset(unittest.TestCase):
|
|||
|
||||
def _create_dataloader(self, batch_size, r, bgs):
|
||||
items = ljspeech(c.data_path, "metadata.csv")
|
||||
|
||||
# add a default language because the TTSDataset now expects a language
|
||||
language = ""
|
||||
items = [[*item, language] for item in items]
|
||||
|
||||
dataset = TTSDataset(
|
||||
r,
|
||||
c.text_cleaner,
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.languages import get_language_weighted_sampler
|
||||
|
||||
# Fixing random state to avoid random fails
|
||||
torch.manual_seed(0)
|
||||
|
||||
dataset_config_en = BaseDatasetConfig(
|
||||
name="ljspeech",
|
||||
meta_file_train="metadata.csv",
|
||||
meta_file_val="metadata.csv",
|
||||
path="tests/data/ljspeech",
|
||||
language="en",
|
||||
)
|
||||
|
||||
dataset_config_pt = BaseDatasetConfig(
|
||||
name="ljspeech",
|
||||
meta_file_train="metadata.csv",
|
||||
meta_file_val="metadata.csv",
|
||||
path="tests/data/ljspeech",
|
||||
language="pt-br",
|
||||
)
|
||||
|
||||
# Adding the EN samples twice to create an unbalanced dataset
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
|
||||
)
|
||||
|
||||
|
||||
def is_balanced(lang_1, lang_2):
|
||||
return 0.85 < lang_1 / lang_2 < 1.2
|
||||
|
||||
|
||||
random_sampler = torch.utils.data.RandomSampler(train_samples)
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
||||
en, pt = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index][3] == "en":
|
||||
en += 1
|
||||
else:
|
||||
pt += 1
|
||||
|
||||
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
|
||||
|
||||
weighted_sampler = get_language_weighted_sampler(train_samples)
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||
en, pt = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index][3] == "en":
|
||||
en += 1
|
||||
else:
|
||||
pt += 1
|
||||
|
||||
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"en": 0,
|
||||
"fr-fr": 1,
|
||||
"pt-br": 2
|
||||
}
|
|
@ -0,0 +1,240 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.models.vits import Vits, VitsArgs
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||
|
||||
|
||||
torch.manual_seed(1)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
class TestVits(unittest.TestCase):
|
||||
def test_init_multispeaker(self):
|
||||
num_speakers = 10
|
||||
args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
|
||||
model = Vits(args)
|
||||
assertHasAttr(self, model, "emb_g")
|
||||
|
||||
args = VitsArgs(num_speakers=0, use_speaker_embedding=True)
|
||||
model = Vits(args)
|
||||
assertHasNotAttr(self, model, "emb_g")
|
||||
|
||||
args = VitsArgs(num_speakers=10, use_speaker_embedding=False)
|
||||
model = Vits(args)
|
||||
assertHasNotAttr(self, model, "emb_g")
|
||||
|
||||
args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
|
||||
model = Vits(args)
|
||||
self.assertEqual(model.embedded_speaker_dim, 101)
|
||||
|
||||
def test_init_multilingual(self):
|
||||
args = VitsArgs(language_ids_file=None, use_language_embedding=False)
|
||||
model = Vits(args)
|
||||
self.assertEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, 0)
|
||||
self.assertEqual(model.emb_l, None)
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE)
|
||||
model = Vits(args)
|
||||
self.assertNotEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, 0)
|
||||
self.assertEqual(model.emb_l, None)
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
|
||||
model = Vits(args)
|
||||
self.assertNotEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
||||
self.assertNotEqual(model.emb_l, None)
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
|
||||
model = Vits(args)
|
||||
self.assertNotEqual(model.language_manager, None)
|
||||
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
|
||||
self.assertNotEqual(model.emb_l, None)
|
||||
|
||||
def test_get_aux_input(self):
|
||||
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
|
||||
args = VitsArgs()
|
||||
model = Vits(args)
|
||||
aux_out = model.get_aux_input(aux_input)
|
||||
|
||||
speaker_id = torch.randint(10, (1,))
|
||||
language_id = torch.randint(10, (1,))
|
||||
d_vector = torch.rand(1, 128)
|
||||
aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
|
||||
aux_out = model.get_aux_input(aux_input)
|
||||
self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
|
||||
self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
|
||||
self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)
|
||||
|
||||
def test_voice_conversion(self):
|
||||
num_speakers = 10
|
||||
spec_len = 101
|
||||
spec_effective_len = 50
|
||||
|
||||
args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
|
||||
model = Vits(args)
|
||||
|
||||
ref_inp = torch.randn(1, spec_len, 513)
|
||||
ref_inp_len = torch.randint(1, spec_effective_len, (1,))
|
||||
ref_spk_id = torch.randint(1, num_speakers, (1,))
|
||||
tgt_spk_id = torch.randint(1, num_speakers, (1,))
|
||||
o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
|
||||
|
||||
self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
|
||||
self.assertEqual(y_mask.shape, (1, 1, spec_len))
|
||||
self.assertEqual(y_mask.sum(), ref_inp_len[0])
|
||||
self.assertEqual(z.shape, (1, args.hidden_channels, spec_len))
|
||||
self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len))
|
||||
self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len))
|
||||
|
||||
def _init_inputs(self, config):
|
||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device)
|
||||
spec_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||
spec_lengths[-1] = spec.size(2)
|
||||
waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device)
|
||||
return input_dummy, input_lengths, spec, spec_lengths, waveform
|
||||
|
||||
def _check_forward_outputs(self, config, output_dict, encoder_config=None):
|
||||
self.assertEqual(
|
||||
output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
|
||||
)
|
||||
self.assertEqual(output_dict["alignments"].shape, (8, 128, 30))
|
||||
self.assertEqual(output_dict["alignments"].max(), 1)
|
||||
self.assertEqual(output_dict["alignments"].min(), 0)
|
||||
self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30))
|
||||
self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30))
|
||||
self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30))
|
||||
self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30))
|
||||
self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30))
|
||||
self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30))
|
||||
self.assertEqual(
|
||||
output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
|
||||
)
|
||||
if encoder_config:
|
||||
self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
|
||||
self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
|
||||
else:
|
||||
self.assertEqual(output_dict["gt_spk_emb"], None)
|
||||
self.assertEqual(output_dict["syn_spk_emb"], None)
|
||||
|
||||
def test_forward(self):
|
||||
num_speakers = 0
|
||||
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
|
||||
config.model_args.spec_segment_size = 10
|
||||
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
|
||||
model = Vits(config).to(device)
|
||||
output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
|
||||
self._check_forward_outputs(config, output_dict)
|
||||
|
||||
def test_multispeaker_forward(self):
|
||||
num_speakers = 10
|
||||
|
||||
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
|
||||
config.model_args.spec_segment_size = 10
|
||||
|
||||
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
|
||||
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
|
||||
|
||||
model = Vits(config).to(device)
|
||||
output_dict = model.forward(
|
||||
input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
|
||||
)
|
||||
self._check_forward_outputs(config, output_dict)
|
||||
|
||||
def test_multilingual_forward(self):
|
||||
num_speakers = 10
|
||||
num_langs = 3
|
||||
|
||||
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
|
||||
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
|
||||
|
||||
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
|
||||
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
|
||||
lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
|
||||
|
||||
model = Vits(config).to(device)
|
||||
output_dict = model.forward(
|
||||
input_dummy,
|
||||
input_lengths,
|
||||
spec,
|
||||
spec_lengths,
|
||||
waveform,
|
||||
aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
|
||||
)
|
||||
self._check_forward_outputs(config, output_dict)
|
||||
|
||||
def test_secl_forward(self):
|
||||
num_speakers = 10
|
||||
num_langs = 3
|
||||
|
||||
speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
|
||||
speaker_encoder_config.model_params["use_torch_spec"] = True
|
||||
speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.speaker_encoder = speaker_encoder
|
||||
|
||||
args = VitsArgs(
|
||||
language_ids_file=LANG_FILE,
|
||||
use_language_embedding=True,
|
||||
spec_segment_size=10,
|
||||
use_speaker_encoder_as_loss=True,
|
||||
)
|
||||
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
|
||||
config.audio.sample_rate = 16000
|
||||
|
||||
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
|
||||
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
|
||||
lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
|
||||
|
||||
model = Vits(config, speaker_manager=speaker_manager).to(device)
|
||||
output_dict = model.forward(
|
||||
input_dummy,
|
||||
input_lengths,
|
||||
spec,
|
||||
spec_lengths,
|
||||
waveform,
|
||||
aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
|
||||
)
|
||||
self._check_forward_outputs(config, output_dict, speaker_encoder_config)
|
||||
|
||||
def test_inference(self):
|
||||
num_speakers = 0
|
||||
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
|
||||
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
|
||||
model = Vits(config).to(device)
|
||||
_ = model.inference(input_dummy)
|
||||
|
||||
def test_multispeaker_inference(self):
|
||||
num_speakers = 10
|
||||
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
|
||||
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
|
||||
speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
|
||||
model = Vits(config).to(device)
|
||||
_ = model.inference(input_dummy, {"speaker_ids": speaker_ids})
|
||||
|
||||
    def test_multilingual_inference(self):
        num_speakers = 10
        num_langs = 3
        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
        model = Vits(config).to(device)
        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})

@@ -0,0 +1,62 @@
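# One-epoch smoke test for multi-speaker VITS training with d-vectors: train on the LJSpeech
# test data, then restore from the newest output folder and continue for one more epoch.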
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-0"],
    ],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# activate multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256


config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

@@ -0,0 +1,91 @@
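# Same train/restore cycle for multilingual VITS: two dataset configs that differ only in their
# `language` tag, language embeddings enabled, and speakers conditioned with d-vectors instead
# of a learned speaker-embedding layer.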
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


dataset_config_en = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=False,
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-0", None, "en"],
        ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
    ],
    datasets=[dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# activate multilingual mode
config.model_args.use_language_embedding = True
config.use_language_embedding = True

# deactivate multispeaker mode
config.model_args.use_speaker_embedding = False
config.use_speaker_embedding = False

# activate multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256
config.d_vector_dim = 256

# use the stochastic duration predictor (SDP)
config.model_args.use_sdp = True
config.use_sdp = True

# deactivate the language-weighted sampler
config.use_language_weighted_sampler = False

config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

@@ -0,0 +1,88 @@
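# Multilingual VITS variant that uses a learned speaker-embedding layer (no d-vectors) and the
# language-weighted sampler, again trained for one epoch, restored, and continued.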
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


dataset_config_en = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech", None, "en"],
        ["Be a voice, not an echo.", "ljspeech", None, "pt-br"],
    ],
    datasets=[dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# activate multilingual mode
config.model_args.use_language_embedding = True
config.use_language_embedding = True
# activate multispeaker mode
config.model_args.use_speaker_embedding = True
config.use_speaker_embedding = True

# deactivate multispeaker d-vec mode
config.model_args.use_d_vector_file = False
config.use_d_vector_file = False

# disable the stochastic duration predictor (SDP)
config.model_args.use_sdp = False
config.use_sdp = False

# activate the language-weighted sampler
config.use_language_weighted_sampler = True

config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

@@ -0,0 +1,63 @@
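# Multi-speaker, single-language VITS training test with a learned speaker embedding; the
# dataset is passed on the command line through `--coqpit.datasets.0.*` overrides.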
import glob
import os
import shutil

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = VitsConfig(
    batch_size=2,
    eval_batch_size=2,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    use_espeak_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech"],
    ],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60

# activate multispeaker mode with a learned speaker embedding (d-vectors disabled)
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_dim = 256


config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

@@ -25,7 +25,7 @@ config = VitsConfig(
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
        ["Be a voice, not an echo."],
    ],
)
config.audio.do_trim_silence = True

@@ -4,6 +4,7 @@ import os
import shutil

from tests import get_tests_output_path, run_cli
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager

@@ -17,21 +18,30 @@ def test_run_all_models():
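    # The updated loop also loads any speaker/language files bundled with a downloaded model so
    # that multi-speaker and multilingual models can be exercised through the CLI with a valid
    # `--speaker_idx` and `--language_idx`.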
    manager = ModelManager(output_prefix=get_tests_output_path())
    model_names = manager.list_models()
    for model_name in model_names:
        print(f"\n > Run - {model_name}")
        model_path, _, _ = manager.download_model(model_name)
        if "tts_models" in model_name:
            local_download_dir = os.path.dirname(model_path)
            # download and run the model
            speaker_files = glob.glob(local_download_dir + "/speaker*")
            language_files = glob.glob(local_download_dir + "/language*")
            language_id = ""
            if len(speaker_files) > 0:
                # multi-speaker model
                if "speaker_ids" in speaker_files[0]:
                    speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
                elif "speakers" in speaker_files[0]:
                    speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])

                # multi-lingual model - Assuming multi-lingual models are also multi-speaker
                if len(language_files) > 0 and "language_ids" in language_files[0]:
                    language_manager = LanguageManager(language_ids_file_path=language_files[0])
                    language_id = language_manager.language_names[0]

                speaker_id = list(speaker_manager.speaker_ids.keys())[0]
                run_cli(
                    f"tts --model_name {model_name} "
                    f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
                    f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
                )
            else:
                # single-speaker model