Merge branch 'pr/Edresson/731-rebased' into dev

Commit: d37cfe474a
Author: Eren Gölge
Date: 2022-01-01 15:37:35 +00:00
55 changed files with 2625 additions and 323 deletions

.gitignore
View File

@ -128,6 +128,8 @@ core
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
recipes/vctk/VCTK/*
recipes/**/*.npy
recipes/**/*.json
VCTK-Corpus-removed-silence/*
# ignore training logs
@ -161,4 +163,5 @@ speakers.json
internal/*
*_pitch.npy
*_phoneme.npy
wandb
wandb
depot/*

View File

@ -1,5 +1,17 @@
{
"tts_models": {
"multilingual":{
"multi-dataset":{
"your_tts":{
"description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.5.0_models/tts_models--multilingual--multi-dataset--your_tts.zip",
"default_vocoder": null,
"commit": "e9a1953e",
"license": "CC BY-NC-ND 4.0",
"contact": "egolge@coqui.ai"
}
}
},
"en": {
"ek1": {
"tacotron2": {

View File

@ -12,7 +12,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids,
d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
@ -234,8 +234,13 @@ def main(args): # pylint: disable=redefined-outer-name
# use eval and training partitions
meta_data = meta_data_train + meta_data_eval
# parse speakers
speaker_manager = get_speaker_manager(c, args, meta_data_train)
# init speaker manager
if c.use_speaker_embedding:
speaker_manager = SpeakerManager(data_items=meta_data)
elif c.use_d_vector_file:
speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
else:
speaker_manager = None
# setup model
model = setup_model(c)

View File

@ -0,0 +1,62 @@
"""Find all the unique characters in a dataset"""
import argparse
import multiprocessing
from argparse import RawTextHelpFormatter
from tqdm.contrib.concurrent import process_map
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text import text2phone
def compute_phonemes(item):
try:
text = item[0]
language = item[-1]
ph = text2phone(text, language, use_espeak_phonemes=c.use_espeak_phonemes).split("|")
except:
return []
return list(set(ph))
def main():
# pylint: disable=W0601
global c
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
"""
Example runs:
python TTS/bin/find_unique_chars.py --config_path config.json
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
args = parser.parse_args()
c = load_config(args.config_path)
# load all datasets
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
items = train_items + eval_items
print("Num items:", len(items))
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
phones = []
for ph in phonemes:
phones.extend(ph)
phones = set(phones)
lower_phones = filter(lambda c: c.islower(), phones)
phones_force_lower = [c.lower() for c in phones]
phones_force_lower = set(phones_force_lower)
print(f" > Number of unique phonemes: {len(phones)}")
print(f" > Unique phonemes: {''.join(sorted(phones))}")
print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
if __name__ == "__main__":
main()
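For orientation, a standalone sketch of the phonemization call this script relies on (not part of the diff). The sentence and the "en-us" language code are illustrative values, and espeak/gruut support for that code is assumed:
from TTS.tts.utils.text import text2phone

# phonemize one sentence; the language code here is only an example
ph = text2phone("Be a voice, not an echo.", "en-us", use_espeak_phonemes=True).split("|")
print(sorted(set(ph)))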

View File

@ -0,0 +1,89 @@
import argparse
import glob
import multiprocessing
import os
import pathlib
from tqdm.contrib.concurrent import process_map
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
def remove_silence(filepath):
output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
# ignore if the file exists
if os.path.exists(output_path) and not args.force:
return
# create all directory structure
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# load wave
audio, sample_rate = read_wave(filepath)
# get speech segments
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness)
segments = list(segments)
num_segments = len(segments)
flag = False
# create the output wave
if num_segments != 0:
for i, segment in reversed(list(enumerate(segments))):
if i >= 1:
if not flag:
concat_segment = segment
flag = True
else:
concat_segment = segment + concat_segment
else:
if flag:
segment = segment + concat_segment
# print("Saving: ", output_path)
write_wave(output_path, segment, sample_rate)
return
else:
print("> Just Copying the file to:", output_path)
# if fail to remove silence just write the file
write_wave(output_path, audio, sample_rate)
return
def preprocess_audios():
files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
print("> Number of files: ", len(files))
if not args.force:
print("> Ignoring files that already exist in the output directory.")
if files:
# create threads
num_threads = multiprocessing.cpu_count()
process_map(remove_silence, files, max_workers=num_threads, chunksize=15)
else:
print("> No files Found !")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2"
)
parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
parser.add_argument(
"-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
)
parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
parser.add_argument(
"-g",
"--glob",
type=str,
default="**/*.wav",
help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
)
parser.add_argument(
"-a",
"--aggressiveness",
type=int,
default=2,
help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
)
args = parser.parse_args()
preprocess_audios()

View File

@ -152,12 +152,19 @@ If you don't specify any models, then it uses LJSpeech based English model.
# args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
parser.add_argument(
"--speaker_idx",
type=str,
help="Target speaker ID for a multi-speaker TTS model.",
default=None,
)
parser.add_argument(
"--language_idx",
type=str,
help="Target language ID for a multi-lingual TTS model.",
default=None,
)
parser.add_argument(
"--speaker_wav",
nargs="+",
@ -173,6 +180,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
const=True,
default=False,
)
parser.add_argument(
"--list_language_idxs",
help="List available language ids for the defined multi-lingual model.",
type=str2bool,
nargs="?",
const=True,
default=False,
)
# aux args
parser.add_argument(
"--save_spectogram",
@ -184,7 +199,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
args = parser.parse_args()
# print the description if either text or list_models is not set
if args.text is None and not args.list_models and not args.list_speaker_idxs:
if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
parser.parse_args(["-h"])
# load model manager
@ -194,6 +209,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path = None
config_path = None
speakers_file_path = None
language_ids_file_path = None
vocoder_path = None
vocoder_config_path = None
encoder_path = None
@ -217,6 +233,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path = args.model_path
config_path = args.config_path
speakers_file_path = args.speakers_file_path
language_ids_file_path = args.language_ids_file_path
if args.vocoder_path is not None:
vocoder_path = args.vocoder_path
@ -231,6 +248,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
model_path,
config_path,
speakers_file_path,
language_ids_file_path,
vocoder_path,
vocoder_config_path,
encoder_path,
@ -246,6 +264,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(synthesizer.tts_model.speaker_manager.speaker_ids)
return
# query language ids of a multi-lingual model.
if args.list_language_idxs:
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.language_id_mapping)
return
# check the arguments against a multi-speaker model.
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
print(
@ -258,7 +284,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(" > Text: {}".format(args.text))
# kick it
wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
from TTS.speaker_encoder.utils.training import init_training
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
@ -151,7 +151,7 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_eval
ap = AudioProcessor(**c.audio)
model = setup_model(c)
model = setup_speaker_encoder_model(c)
optimizer = RAdam(model.parameters(), lr=c.lr)

View File

@ -1,9 +1,10 @@
import os
from TTS.config import load_config, register_config
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@ -45,15 +46,32 @@ def main():
ap = AudioProcessor(**config.audio)
# init speaker manager
if config.use_speaker_embedding:
if check_config_and_model_args(config, "use_speaker_embedding", True):
speaker_manager = SpeakerManager(data_items=train_samples + eval_samples)
elif config.use_d_vector_file:
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:
config.num_speakers = speaker_manager.num_speakers
elif check_config_and_model_args(config, "use_d_vector_file", True):
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:
config.num_speakers = speaker_manager.num_speakers
else:
speaker_manager = None
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
language_manager = LanguageManager(config=config)
if hasattr(config, "model_args"):
config.model_args.num_languages = language_manager.num_languages
else:
config.num_languages = language_manager.num_languages
else:
language_manager = None
# init the model from config
model = setup_model(config, speaker_manager)
model = setup_model(config, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(

View File

@ -95,3 +95,38 @@ def load_config(config_path: str) -> None:
config = config_class()
config.from_dict(config_dict)
return config
def check_config_and_model_args(config, arg_name, value):
"""Check the give argument in `config.model_args` if exist or in `config` for
the given value.
Return False if the argument does not exist in `config.model_args` or `config`.
This is to patch up the compatibility between models with and without `model_args`.
TODO: Remove this in the future with a unified approach.
"""
if hasattr(config, "model_args"):
if arg_name in config.model_args:
return config.model_args[arg_name] == value
if hasattr(config, arg_name):
return config[arg_name] == value
return False
def get_from_config_or_model_args(config, arg_name):
"""Get the given argument from `config.model_args` if exist or in `config`."""
if hasattr(config, "model_args"):
if arg_name in config.model_args:
return config.model_args[arg_name]
return config[arg_name]
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
"""Get the given argument from `config.model_args` if exist or in `config`."""
if hasattr(config, "model_args"):
if arg_name in config.model_args:
return config.model_args[arg_name]
if hasattr(config, arg_name):
return config[arg_name]
return def_val
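A small sketch of the lookup order these helpers implement (not part of the diff). It uses a dict-backed stand-in instead of a real Coqpit config; the DummyConfig class is purely illustrative:
from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default

class DummyConfig(dict):
    """Minimal attribute/item-access object standing in for a Coqpit config."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e

config = DummyConfig(model_args=DummyConfig(use_speaker_embedding=True), num_speakers=0)

# model_args takes precedence over the top-level config
print(check_config_and_model_args(config, "use_speaker_embedding", True))     # True
# falls back to the default when the key is in neither place
print(get_from_config_or_model_args_with_default(config, "d_vector_dim", 0))  # 0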

View File

@ -60,6 +60,12 @@ class BaseAudioConfig(Coqpit):
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
do_rms_norm (bool, optional):
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
db_level (float, optional):
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
power (float):
Exponent used for expanding spectrogram levels before running Griffin Lim. It helps to reduce the
artifacts in the synthesized voice. Defaults to 1.5.
@ -116,6 +122,9 @@ class BaseAudioConfig(Coqpit):
# silence trimming
do_trim_silence: bool = True
trim_db: int = 45
# rms volume normalization
do_rms_norm: bool = False
db_level: float = None
# griffin-lim params
power: float = 1.5
griffin_lim_iters: int = 60
@ -198,7 +207,8 @@ class BaseDatasetConfig(Coqpit):
name: str = ""
path: str = ""
meta_file_train: str = ""
ununsed_speakers: List[str] = None
ignored_speakers: List[str] = None
language: str = ""
meta_file_val: str = ""
meta_file_attn_mask: str = ""
@ -335,6 +345,8 @@ class BaseTrainingConfig(Coqpit):
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
use_language_weighted_sampler: bool = False
# paths
output_path: str = None
# distributed

View File

@ -100,7 +100,15 @@ if args.vocoder_path is not None:
# load models
synthesizer = Synthesizer(
model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
tts_checkpoint=model_path,
tts_config_path=config_path,
tts_speakers_file=speakers_file_path,
tts_languages_file=None,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
encoder_checkpoint="",
encoder_config="",
use_cuda=args.use_cuda,
)
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
@ -165,7 +173,7 @@ def tts():
style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
wavs = synthesizer.tts(text, speaker_idx=speaker_idx, style_wav=style_wav)
wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype="audio/wav")

View File

@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
feats = torch.stack(feats)
labels = torch.stack(labels)
return feats.transpose(1, 2), labels
return feats, labels

View File

@ -1,7 +1,9 @@
import numpy as np
import torch
import torchaudio
from torch import nn
from TTS.speaker_encoder.models.resnet import PreEmphasis
from TTS.utils.io import load_fsspec
@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module):
class LSTMSpeakerEncoder(nn.Module):
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
def __init__(
self,
input_dim,
proj_dim=256,
lstm_dim=768,
num_lstm_layers=3,
use_lstm_with_projection=True,
use_torch_spec=False,
audio_config=None,
):
super().__init__()
self.use_lstm_with_projection = use_lstm_with_projection
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
layers = []
# choose the LSTM layer
if use_lstm_with_projection:
@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module):
else:
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
self._init_layers()
def _init_layers(self):
@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module):
elif "weight" in name:
nn.init.xavier_normal_(param)
def forward(self, x):
# TODO: implement state passing for lstms
def forward(self, x, l2_norm=True):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If the input is a waveform, `use_torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if self.use_torch_spec:
x.squeeze_(1)
x = self.torch_spec(x)
x = self.instancenorm(x).transpose(1, 2)
d = self.layers(x)
if self.use_lstm_with_projection:
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
else:
d = d[:, -1]
if l2_norm:
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
@torch.no_grad()
def inference(self, x):
d = self.layers.forward(x)
if self.use_lstm_with_projection:
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
else:
d = torch.nn.functional.normalize(d, p=2, dim=1)
def inference(self, x, l2_norm=True):
d = self.forward(x, l2_norm=l2_norm)
return d
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):

View File

@ -1,10 +1,25 @@
import numpy as np
import torch
import torchaudio
from torch import nn
# from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@ -70,12 +85,17 @@ class ResNetSpeakerEncoder(nn.Module):
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
use_torch_spec=False,
audio_config=None,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
self.log_input = log_input
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0])
@ -88,6 +108,36 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
@ -140,9 +190,23 @@ class ResNetSpeakerEncoder(nn.Module):
return out
def forward(self, x, l2_norm=False):
x = x.transpose(1, 2)
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If the input is a waveform, `use_torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
x.squeeze_(1)
# if using torch_spec, compute the spectrogram here; otherwise use the mel spec computed by the AudioProcessor
if self.use_torch_spec:
x = self.torch_spec(x)
if self.log_input:
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
@ -175,11 +239,19 @@ class ResNetSpeakerEncoder(nn.Module):
return x
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
def inference(self, x, l2_norm=False):
return self.forward(x, l2_norm)
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
# when using torch_spec, the input is a waveform, so map the frame count to waveform samples
if self.use_torch_spec:
num_frames = num_frames * self.audio_config["hop_length"]
max_len = x.shape[1]
if max_len < num_frames:
@ -195,11 +267,10 @@ class ResNetSpeakerEncoder(nn.Module):
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.forward(frames_batch, l2_norm=True)
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
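As a quick check of the PreEmphasis module introduced above, a pure-shape sketch (not part of the diff). The random batch is a placeholder, and the import path follows the import added in the LSTM encoder above:
import torch

from TTS.speaker_encoder.models.resnet import PreEmphasis

wav = torch.randn(4, 16000)        # (batch, samples) of fake audio
pre_emphasis = PreEmphasis(coefficient=0.97)
out = pre_emphasis(wav)            # reflect-pad + 1D conv keeps the shape: (4, 16000)
print(out.shape)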

View File

@ -170,16 +170,24 @@ def to_camel(text):
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(c):
if c.model_params["model_name"].lower() == "lstm":
def setup_speaker_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
c.model_params["input_dim"],
c.model_params["proj_dim"],
c.model_params["lstm_dim"],
c.model_params["num_lstm_layers"],
config.model_params["input_dim"],
config.model_params["proj_dim"],
config.model_params["lstm_dim"],
config.model_params["num_lstm_layers"],
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
elif config.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(
input_dim=config.model_params["input_dim"],
proj_dim=config.model_params["proj_dim"],
log_input=config.model_params.get("log_input", False),
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
elif c.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
return model
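A minimal usage sketch of the renamed factory (not part of the diff). The config path is a placeholder and it assumes a speaker encoder config that defines model_params as referenced above:
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model

config = load_config("speaker_encoder_config.json")  # placeholder path
model = setup_speaker_encoder_model(config)
print(model.__class__.__name__)  # LSTMSpeakerEncoder or ResNetSpeakerEncoder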

View File

@ -202,7 +202,7 @@ class Trainer:
os.makedirs(output_path, exist_ok=True)
# copy training assets to the output folder
copy_model_files(config, output_path, new_fields=None)
copy_model_files(config, output_path)
# init class members
self.args = args
@ -439,7 +439,7 @@ class Trainer:
if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
print(" > Restoring Scaler...")
scaler = _restore_list_objs(checkpoint["scaler"], scaler)
except (KeyError, RuntimeError):
except (KeyError, RuntimeError, ValueError):
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], config)

View File

@ -82,8 +82,14 @@ class VitsConfig(BaseTTSConfig):
add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.
test_sentences (List[str]):
List of sentences to be used for testing.
test_sentences (List[List]):
List of sentences with speaker and language information to be used for testing.
language_ids_file (str):
Path to the language ids file.
use_language_embedding (bool):
If true, language embedding is used. Defaults to `False`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
@ -117,6 +123,7 @@ class VitsConfig(BaseTTSConfig):
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
speaker_encoder_loss_alpha: float = 1.0
# data loader params
return_wav: bool = True
@ -130,13 +137,13 @@ class VitsConfig(BaseTTSConfig):
add_blank: bool = True
# testing
test_sentences: List[str] = field(
test_sentences: List[List] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
["Be a voice, not an echo."],
["I'm sorry Dave. I'm afraid I can't do that."],
["This cake is great. It's so delicious and moist."],
["Prior to November 22, 1963."],
]
)
@ -146,29 +153,15 @@ class VitsConfig(BaseTTSConfig):
use_speaker_embedding: bool = False
speakers_file: str = None
speaker_embedding_channels: int = 256
language_ids_file: str = None
use_language_embedding: bool = False
# use d-vectors
use_d_vector_file: bool = False
d_vector_file: str = False
d_vector_file: str = None
d_vector_dim: int = None
def __post_init__(self):
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
if self.num_speakers > 0:
self.model_args.num_speakers = self.num_speakers
# speaker embedding settings
if self.use_speaker_embedding:
self.model_args.use_speaker_embedding = True
if self.speakers_file:
self.model_args.speakers_file = self.speakers_file
if self.speaker_embedding_channels:
self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
# d-vector settings
if self.use_d_vector_file:
self.model_args.use_d_vector_file = True
if self.d_vector_dim is not None and self.d_vector_dim > 0:
self.model_args.d_vector_dim = self.d_vector_dim
if self.d_vector_file:
self.model_args.d_vector_file = self.d_vector_file
for key, val in self.model_args.items():
if hasattr(self, key):
self[key] = val
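A hedged configuration sketch for the new multilingual fields (not part of the diff). The import paths and the VitsArgs field names are assumed from the model changes later in this commit, and the speaker name is a placeholder:
from TTS.tts.configs import VitsConfig
from TTS.tts.models.vits import VitsArgs

config = VitsConfig(
    model_args=VitsArgs(
        use_speaker_embedding=True,
        use_language_embedding=True,
        embedded_language_dim=4,
    ),
    # new format: [text, speaker_name, style_wav, language_name]
    test_sentences=[["Be a voice, not an echo.", "VCTK_p277", None, "en"]],
)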

View File

@ -67,16 +67,22 @@ def load_tts_samples(
root_path = dataset["path"]
meta_file_train = dataset["meta_file_train"]
meta_file_val = dataset["meta_file_val"]
ignored_speakers = dataset["ignored_speakers"]
language = dataset["language"]
# setup the right data processor
if formatter is None:
formatter = _get_formatter_by_name(name)
# load train set
meta_data_train = formatter(root_path, meta_file_train)
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
meta_data_train = [[*item, language] for item in meta_data_train]
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val)
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [[*item, language] for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval
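For orientation, a sketch of calling the loader with the new language and ignored_speakers dataset fields (not part of the diff). The dataset path and speaker IDs are placeholders, the import paths are assumed, and passing a single config instead of a list is assumed to still be supported:
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

dataset_config = BaseDatasetConfig(
    name="vctk",
    path="/data/VCTK-Corpus/",   # placeholder path
    meta_file_train="",
    language="en",
    ignored_speakers=["p225"],   # hypothetical speakers to drop
)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# each sample now carries the language at the end: [text, wav_file, speaker_name, language]
print(train_samples[0])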

View File

@ -37,6 +37,7 @@ class TTSDataset(Dataset):
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
verbose: bool = False,
):
@ -122,7 +123,9 @@ class TTSDataset(Dataset):
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
@ -197,10 +200,10 @@ class TTSDataset(Dataset):
def load_data(self, idx):
item = self.items[idx]
if len(item) == 4:
text, wav_file, speaker_name, attn_file = item
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
else:
text, wav_file, speaker_name = item
text, wav_file, speaker_name, language_name = item
attn = None
raw_text = text
@ -218,7 +221,7 @@ class TTSDataset(Dataset):
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
language_name if language_name else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
@ -260,6 +263,7 @@ class TTSDataset(Dataset):
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
"language_name": language_name,
"wav_file_name": os.path.basename(wav_file),
}
return sample
@ -269,6 +273,7 @@ class TTSDataset(Dataset):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
func_args[3] = item[3]
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes
@ -335,7 +340,6 @@ class TTSDataset(Dataset):
else:
lengths = np.array([len(ins[0]) for ins in self.items])
# sort items based on the sequence length in ascending order
idxs = np.argsort(lengths)
new_items = []
ignored = []
@ -345,10 +349,7 @@ class TTSDataset(Dataset):
ignored.append(idx)
else:
new_items.append(self.items[idx])
# shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
@ -356,14 +357,8 @@ class TTSDataset(Dataset):
temp_items = new_items[offset:end_offset]
random.shuffle(temp_items)
new_items[offset:end_offset] = temp_items
if len(new_items) == 0:
raise RuntimeError(" [!] No items left after filtering.")
# update items to the new sorted items
self.items = new_items
# logging
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
@ -413,9 +408,14 @@ class TTSDataset(Dataset):
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
# get language ids from language names
if self.language_id_mapping is not None:
language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
else:
language_ids = None
# get pre-computed d-vectors
if self.d_vector_mapping is not None:
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
wav_files_names = list(batch["wav_file_name"])
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
else:
d_vectors = None
@ -466,6 +466,9 @@ class TTSDataset(Dataset):
if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
@ -528,6 +531,7 @@ class TTSDataset(Dataset):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"language_ids": language_ids,
}
raise TypeError(
@ -542,7 +546,6 @@ class TTSDataset(Dataset):
class PitchExtractor:
"""Pitch Extractor for computing F0 from wav files.
Args:
items (List[List]): Dataset samples.
verbose (bool): Whether to print the progress.

View File

@ -12,7 +12,7 @@ from tqdm import tqdm
########################
def tweb(root_path, meta_file):
def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalize TWEB dataset.
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
"""
@ -28,7 +28,7 @@ def tweb(root_path, meta_file):
return items
def mozilla(root_path, meta_file):
def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
@ -43,7 +43,7 @@ def mozilla(root_path, meta_file):
return items
def mozilla_de(root_path, meta_file):
def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
@ -59,7 +59,7 @@ def mozilla_de(root_path, meta_file):
return items
def mailabs(root_path, meta_files=None):
def mailabs(root_path, meta_files=None, ignored_speakers=None):
"""Normalizes M-AI-Labs meta data files to TTS format
Args:
@ -68,25 +68,34 @@ def mailabs(root_path, meta_files=None):
recursively. Defaults to None
"""
speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
if meta_files is None:
if not meta_files:
csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
else:
csv_files = meta_files
# meta_files = [f.strip() for f in meta_files.split(",")]
items = []
for csv_file in csv_files:
txt_file = os.path.join(root_path, csv_file)
if os.path.isfile(csv_file):
txt_file = csv_file
else:
txt_file = os.path.join(root_path, csv_file)
folder = os.path.dirname(txt_file)
# determine speaker based on folder structure...
speaker_name_match = speaker_regex.search(txt_file)
if speaker_name_match is None:
continue
speaker_name = speaker_name_match.group("speaker_name")
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
print(" | > {}".format(csv_file))
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
if meta_files is None:
if not meta_files:
wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
else:
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
@ -94,11 +103,12 @@ def mailabs(root_path, meta_files=None):
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
else:
raise RuntimeError("> File %s does not exist!" % (wav_file))
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
return items
def ljspeech(root_path, meta_file):
def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the LJSpeech meta data file to TTS format
https://keithito.com/LJ-Speech-Dataset/"""
txt_file = os.path.join(root_path, meta_file)
@ -113,7 +123,7 @@ def ljspeech(root_path, meta_file):
return items
def ljspeech_test(root_path, meta_file):
def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the LJSpeech meta data file for TTS testing
https://keithito.com/LJ-Speech-Dataset/"""
txt_file = os.path.join(root_path, meta_file)
@ -127,7 +137,7 @@ def ljspeech_test(root_path, meta_file):
return items
def sam_accenture(root_path, meta_file):
def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the sam-accenture meta data file to TTS format
https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)
@ -144,12 +154,12 @@ def sam_accenture(root_path, meta_file):
return items
def ruslan(root_path, meta_file):
def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the RUSLAN meta data file to TTS format
https://ruslan-corpus.github.io/"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
speaker_name = "ruslan"
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
@ -159,11 +169,11 @@ def ruslan(root_path, meta_file):
return items
def css10(root_path, meta_file):
def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the CSS10 dataset file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
speaker_name = "css10"
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
@ -173,7 +183,7 @@ def css10(root_path, meta_file):
return items
def nancy(root_path, meta_file):
def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
@ -187,7 +197,7 @@ def nancy(root_path, meta_file):
return items
def common_voice(root_path, meta_file):
def common_voice(root_path, meta_file, ignored_speakers=None):
"""Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file)
items = []
@ -198,15 +208,19 @@ def common_voice(root_path, meta_file):
cols = line.split("\t")
text = cols[2]
speaker_name = cols[0]
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append([text, wav_file, "MCV_" + speaker_name])
return items
def libri_tts(root_path, meta_files=None):
def libri_tts(root_path, meta_files=None, ignored_speakers=None):
"""https://ai.google/tools/datasets/libri-tts/"""
items = []
if meta_files is None:
if not meta_files:
meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
else:
if isinstance(meta_files, str):
@ -222,13 +236,17 @@ def libri_tts(root_path, meta_files=None):
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
wav_file = os.path.join(_root_path, file_name + ".wav")
text = cols[2]
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.append([text, wav_file, "LTTS_" + speaker_name])
for item in items:
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
return items
def custom_turkish(root_path, meta_file):
def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "turkish-female"
@ -247,7 +265,7 @@ def custom_turkish(root_path, meta_file):
# ToDo: add the dataset link when the dataset is released publicly
def brspeech(root_path, meta_file):
def brspeech(root_path, meta_file, ignored_speakers=None):
"""BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file)
items = []
@ -258,21 +276,25 @@ def brspeech(root_path, meta_file):
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[2]
speaker_name = cols[3]
items.append([text, wav_file, speaker_name])
speaker_id = cols[3]
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append([text, wav_file, speaker_id])
return items
def vctk(root_path, meta_files=None, wavs_path="wav48"):
def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
@ -282,15 +304,16 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
return items
def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): # pylint: disable=unused-argument
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
items = []
txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for text_file in txt_files:
_, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(meta_files, list): # if is list ignore this speakers ids
if speaker_id in meta_files:
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([None, wav_file, "VCTK_" + speaker_id])
@ -298,7 +321,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
return items
def mls(root_path, meta_files=None):
def mls(root_path, meta_files=None, ignored_speakers=None):
"""http://www.openslr.org/94/"""
items = []
with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
@ -307,19 +330,23 @@ def mls(root_path, meta_files=None):
text = text[:-1]
speaker, book, *_ = file.split("_")
wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav")
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker in ignored_speakers:
continue
items.append([text, wav_file, "MLS_" + speaker])
return items
# ======================================== VOX CELEB ===========================================
def voxceleb2(root_path, meta_file=None):
def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument
"""
:param meta_file Used only for consistency with load_tts_samples api
"""
return _voxcel_x(root_path, meta_file, voxcel_idx="2")
def voxceleb1(root_path, meta_file=None):
def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument
"""
:param meta_file Used only for consistency with load_tts_samples api
"""
@ -361,7 +388,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
return [x.strip().split("|") for x in f.readlines()]
def baker(root_path: str, meta_file: str) -> List[List[str]]:
def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument
"""Normalizes the Baker meta data file to TTS format
Args:
@ -381,7 +408,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
return items
def kokoro(root_path, meta_file):
def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset"""
txt_file = os.path.join(root_path, meta_file)
items = []
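The formatters above now share the ignored_speakers argument; a small hedged sketch calling one of them directly (not part of the diff). The data path and speaker IDs are placeholders and the module path is assumed:
from TTS.tts.datasets.formatters import vctk

# items are [text, wav_path, speaker_name] entries; the listed speakers are skipped
items = vctk("/data/VCTK-Corpus/", ignored_speakers=["p225", "p226"])
print(len(items))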

View File

@ -18,8 +18,13 @@ class DurationPredictor(nn.Module):
dropout_p (float): Dropout rate used after each conv layer.
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
super().__init__()
# add language embedding dim in the input
if language_emb_dim:
in_channels += language_emb_dim
# class arguments
self.in_channels = in_channels
self.filter_channels = hidden_channels
@ -36,7 +41,10 @@ class DurationPredictor(nn.Module):
if cond_channels is not None and cond_channels != 0:
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
if language_emb_dim != 0 and language_emb_dim is not None:
self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)
def forward(self, x, x_mask, g=None, lang_emb=None):
"""
Shapes:
- x: :math:`[B, C, T]`
@ -45,6 +53,10 @@ class DurationPredictor(nn.Module):
"""
if g is not None:
x = x + self.cond(g)
if lang_emb is not None:
x = x + self.cond_lang(lang_emb)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
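The language conditioning above is a plain 1x1 convolution broadcast over time; a self-contained torch sketch of just that step (dimensions are illustrative, not part of the diff):
import torch
from torch import nn

in_channels, language_emb_dim, T = 196, 4, 50
cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)

x = torch.randn(2, in_channels, T)              # duration predictor input [B, C, T]
lang_emb = torch.randn(2, language_emb_dim, 1)  # one language embedding per utterance
x = x + cond_lang(lang_emb)                     # (2, in_channels, 1) broadcasts over the time axis
print(x.shape)                                  # torch.Size([2, 196, 50])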

View File

@ -532,6 +532,7 @@ class VitsGeneratorLoss(nn.Module):
self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
@ -585,6 +586,11 @@ class VitsGeneratorLoss(nn.Module):
l = kl / torch.sum(z_mask)
return l
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
return l
def forward(
self,
waveform,
@ -598,6 +604,9 @@ class VitsGeneratorLoss(nn.Module):
feats_disc_fake,
feats_disc_real,
loss_duration,
use_speaker_encoder_as_loss=False,
gt_spk_emb=None,
syn_spk_emb=None,
):
"""
Shapes:
@ -618,13 +627,20 @@ class VitsGeneratorLoss(nn.Module):
# compute mel spectrograms from the waveforms
mel = self.stft(waveform)
mel_hat = self.stft(waveform_hat)
# compute losses
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss:
loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
loss += loss_se
return_dict["loss_spk_encoder"] = loss_se
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl

View File

@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
num_layers: int,
kernel_size: int,
dropout_p: float,
language_emb_dim: int = None,
):
"""Text Encoder for VITS model.
@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
if language_emb_dim:
hidden_channels += language_emb_dim
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
@ -72,13 +77,18 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
def forward(self, x, x_lengths, lang_emb=None):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
# concatenate the language embedding to the character embeddings
if lang_emb is not None:
x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
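The concatenation above expands a per-utterance language embedding along the time axis before appending it to the character embeddings; a pure-torch shape check (dimensions are illustrative, not part of the diff):
import torch

B, T, hidden, lang_dim = 2, 50, 192, 4
x = torch.randn(B, T, hidden)           # character embeddings [B, T, C]
lang_emb = torch.randn(B, lang_dim, 1)  # one language embedding per utterance

x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
print(x.shape)  # torch.Size([2, 50, 196])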

View File

@ -178,10 +178,21 @@ class StochasticDurationPredictor(nn.Module):
"""
def __init__(
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
dropout_p: float,
num_flows=4,
cond_channels=0,
language_emb_dim=0,
):
super().__init__()
# add language embedding dim in the input
if language_emb_dim:
in_channels += language_emb_dim
# condition encoder text
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
@ -205,7 +216,10 @@ class StochasticDurationPredictor(nn.Module):
if cond_channels != 0 and cond_channels is not None:
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
if language_emb_dim != 0 and language_emb_dim is not None:
self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
@ -217,6 +231,10 @@ class StochasticDurationPredictor(nn.Module):
x = self.pre(x)
if g is not None:
x = x + self.cond(g)
if lang_emb is not None:
x = x + self.cond_lang(lang_emb)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask

View File

@ -2,7 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
def setup_model(config, speaker_manager: "SpeakerManager" = None):
def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
@ -31,7 +31,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
model = MyModel(config, speaker_manager=speaker_manager)
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager)
return model

View File

@ -12,7 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit):
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
This implementation yields 3 possible outcomes:
1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
`config.d_vector_dim` or 512.
You can override this function for new models.
Args:
config (Coqpit): Model configuration.
@ -97,6 +107,57 @@ class BaseTTS(BaseModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
def get_aux_input_from_test_setences(self, sentence_info):
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
text = sentence_info[0]
elif len(sentence_info) == 2:
text, speaker_name = sentence_info
elif len(sentence_info) == 3:
text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info
else:
text = sentence_info
# get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
else:
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
return {
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"d_vector": d_vector,
"language_id": language_id,
}
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.
@ -122,6 +183,7 @@ class BaseTTS(BaseModel):
attn_mask = batch["attns"]
waveform = batch["waveform"]
pitch = batch["pitch"]
language_ids = batch["language_ids"]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@ -169,6 +231,7 @@ class BaseTTS(BaseModel):
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
"language_ids": language_ids,
}
def get_data_loader(
@ -188,8 +251,15 @@ class BaseTTS(BaseModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None
@ -199,7 +269,14 @@ class BaseTTS(BaseModel):
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
# init dataset
if hasattr(self, "language_manager"):
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
else:
language_id_mapping = None
# init dataloader
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
@ -222,7 +299,8 @@ class BaseTTS(BaseModel):
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
d_vector_mapping=d_vector_mapping,
language_id_mapping=language_id_mapping,
)
# pre-compute phonemes
@ -268,7 +346,22 @@ class BaseTTS(BaseModel):
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# init dataloader
# Weighted samplers
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.items)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
@ -340,8 +433,7 @@ class BaseTTS(BaseModel):
return test_figures, test_audios
def on_init_start(self, trainer):
"""Save the speaker.json at the beginning of the training. And update the config.json with the
speakers.json file path."""
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
@ -352,3 +444,13 @@ class BaseTTS(BaseModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.json` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")

View File

@ -1,13 +1,15 @@
import math
import random
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple
import torch
# import torchaudio
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
@ -15,6 +17,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
@ -138,11 +141,50 @@ class VitsArgs(Coqpit):
use_d_vector_file (bool):
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
d_vector_dim (int):
Number of d-vector channels. Defaults to 0.
detach_dp_input (bool):
Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
use_language_embedding (bool):
Enable/Disable language embedding for multilingual models. Defaults to False.
embedded_language_dim (int):
Number of language embedding channels. Defaults to 4.
num_languages (int):
Number of languages for the language embedding layer. Defaults to 0.
language_ids_file (str):
Path to the language mapping file for the Language Manager. Defaults to None.
use_speaker_encoder_as_loss (bool):
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
speaker_encoder_config_path (str):
Path to the speaker encoder config file, used for SCL. Defaults to "".
speaker_encoder_model_path (str):
Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".
freeze_encoder (bool):
Freeze the encoder weights during training. Defaults to False.
freeze_DP (bool):
Freeze the duration predictor weights during training. Defaults to False.
freeze_PE (bool):
Freeze the posterior encoder weights during training. Defaults to False.
freeze_flow_decoder (bool):
Freeze the flow decoder weights during training. Defaults to False.
freeze_waveform_decoder (bool):
Freeze the waveform decoder weights during training. Defaults to False.
"""
num_chars: int = 100
@ -179,11 +221,23 @@ class VitsArgs(Coqpit):
use_speaker_embedding: bool = False
num_speakers: int = 0
speakers_file: str = None
d_vector_file: str = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
detach_dp_input: bool = True
use_language_embedding: bool = False
embedded_language_dim: int = 4
num_languages: int = 0
language_ids_file: str = None
use_speaker_encoder_as_loss: bool = False
speaker_encoder_config_path: str = ""
speaker_encoder_model_path: str = ""
freeze_encoder: bool = False
freeze_DP: bool = False
freeze_PE: bool = False
freeze_flow_decoder: bool = False
freeze_waveform_decoder: bool = False
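For orientation, a minimal sketch of how the new multilingual and SCL arguments might be combined; the file paths are placeholders, not values taken from this change:

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_speaker_embedding=True,
    use_language_embedding=True,
    embedded_language_dim=4,
    language_ids_file="language_ids.json",          # placeholder path
    use_speaker_encoder_as_loss=True,
    speaker_encoder_model_path="model_se.pth.tar",  # placeholder path
    speaker_encoder_config_path="config_se.json",   # placeholder path
)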
class Vits(BaseTTS):
@ -216,13 +270,18 @@ class Vits(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: Coqpit,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__(config)
self.END2END = True
self.speaker_manager = speaker_manager
self.language_manager = language_manager
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
@ -242,6 +301,7 @@ class Vits(BaseTTS):
self.args = args
self.init_multispeaker(config)
self.init_multilingual(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
@ -260,6 +320,7 @@ class Vits(BaseTTS):
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
language_emb_dim=self.embedded_language_dim,
)
self.posterior_encoder = PosteriorEncoder(
@ -289,10 +350,16 @@ class Vits(BaseTTS):
args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim,
language_emb_dim=self.embedded_language_dim,
)
else:
self.duration_predictor = DurationPredictor(
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
args.hidden_channels,
256,
3,
args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim,
language_emb_dim=self.embedded_language_dim,
)
self.waveform_decoder = HifiganGenerator(
@ -318,54 +385,158 @@ class Vits(BaseTTS):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
if hasattr(config, "model_args"):
config = config.model_args
self.num_speakers = self.args.num_speakers
self.num_speakers = config.num_speakers
if self.speaker_manager:
self.num_speakers = self.speaker_manager.num_speakers
if config.use_speaker_embedding:
self._init_speaker_embedding(config)
if self.args.use_speaker_embedding:
self._init_speaker_embedding()
if config.use_d_vector_file:
self._init_d_vector(config)
if self.args.use_d_vector_file:
self._init_d_vector()
def _init_speaker_embedding(self, config):
# TODO: make this a function
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.speaker_encoder is None and (
not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
)
self.speaker_manager.speaker_encoder.eval()
print(" > External Speaker Encoder Loaded !!")
if (
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
):
# TODO: change this with torchaudio Resample
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
self.config.audio["sample_rate"],
self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
)
# pylint: disable=W0101,W0105
""" self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
else:
self.audio_transform = None
"""
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if config.speakers_file is not None:
self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
self.embedded_speaker_dim = config.speaker_embedding_channels
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
def _init_d_vector(self, config):
def _init_d_vector(self):
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
self.embedded_speaker_dim = config.d_vector_dim
self.embedded_speaker_dim = self.args.d_vector_dim
def init_multilingual(self, config: Coqpit):
"""Initialize multilingual modules of a model.
Args:
config (Coqpit): Model configuration.
"""
if self.args.language_ids_file is not None:
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if self.args.use_language_embedding and self.language_manager:
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
torch.nn.init.xavier_uniform_(self.emb_l.weight)
else:
self.embedded_language_dim = 0
self.emb_l = None
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g = None, None
sid, g, lid = None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = aux_input["d_vectors"]
return sid, g
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
if g.ndim == 2:
g = g.unsqueeze_(0)
if "language_ids" in aux_input and aux_input["language_ids"] is not None:
lid = aux_input["language_ids"]
if lid.ndim == 0:
lid = lid.unsqueeze_(0)
return sid, g, lid
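A small sketch of the conditioning shapes this helper produces, using dummy tensors (the d-vector dimension of 512 is only an example):

import torch
import torch.nn.functional as F

d_vectors = torch.randn(2, 512)           # [B, C] batch of d-vectors
g = F.normalize(d_vectors).unsqueeze(-1)  # [B, C, 1], as done above
language_ids = torch.tensor([0, 1])       # [B] integer language IDs
print(g.shape, language_ids.shape)        # torch.Size([2, 512, 1]) torch.Size([2])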
def get_aux_input(self, aux_input: Dict):
sid, g = self._set_cond_input(aux_input)
return {"speaker_id": sid, "style_wav": None, "d_vector": g}
sid, g, lid = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
def get_aux_input_from_test_sentences(self, sentence_info):
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
text = sentence_info[0]
elif len(sentence_info) == 2:
text, speaker_name = sentence_info
elif len(sentence_info) == 3:
text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info
else:
text = sentence_info
# get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
else:
d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=1, randomize=False)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
return {
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"d_vector": d_vector,
"language_id": language_id,
"language_name": language_name,
}
def forward(
self,
@ -373,7 +544,8 @@ class Vits(BaseTTS):
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None},
waveform: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
) -> Dict:
"""Forward pass of the model.
@ -382,7 +554,9 @@ class Vits(BaseTTS):
x_lengths (torch.tensor): Batch of input character sequence lengths.
y (torch.tensor): Batch of input spectrograms.
y_lengths (torch.tensor): Batch of input spectrogram lengths.
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
waveform (torch.tensor): Batch of ground truth waveforms per sample.
aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.
Returns:
Dict: model outputs keyed by the output name.
@ -392,17 +566,24 @@ class Vits(BaseTTS):
- x_lengths: :math:`[B]`
- y: :math:`[B, C, T_spec]`
- y_lengths: :math:`[B]`
- waveform: :math:`[B, T_wav, 1]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
- language_ids: :math:`[B]`
"""
outputs = {}
sid, g = self._set_cond_input(aux_input)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
sid, g, lid = self._set_cond_input(aux_input)
# speaker embedding
if self.num_speakers > 1 and sid is not None:
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@ -428,6 +609,7 @@ class Vits(BaseTTS):
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = loss_duration / torch.sum(x_mask)
else:
@ -436,6 +618,7 @@ class Vits(BaseTTS):
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
@ -447,40 +630,73 @@ class Vits(BaseTTS):
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
o = self.waveform_decoder(z_slice, g=g)
wav_seg = segment(
waveform,
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
# concatenate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)
# resample audio to speaker encoder sample_rate
# pylint: disable=W0105
"""if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)"""
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
else:
gt_spk_emb, syn_spk_emb = None, None
outputs.update(
{
"model_outputs": o,
"alignments": attn.squeeze(1),
"slice_ids": slice_ids,
"z": z,
"z_p": z_p,
"m_p": m_p,
"logs_p": logs_p,
"m_q": m_q,
"logs_q": logs_q,
"waveform_seg": wav_seg,
"gt_spk_emb": gt_spk_emb,
"syn_spk_emb": syn_spk_emb,
}
)
return outputs
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
"""
Shapes:
- x: :math:`[B, T_seq]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
"""
sid, g = self._set_cond_input(aux_input)
sid, g, lid = self._set_cond_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
if self.num_speakers > 0 and sid is not None:
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1)
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
if self.args.use_sdp:
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
logw = self.duration_predictor(
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
)
else:
logw = self.duration_predictor(x, x_mask, g=g)
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
@ -499,12 +715,30 @@ class Vits(BaseTTS):
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
return outputs
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
"""TODO: create an end-point for voice conversion"""
def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
"""Forward pass for voice conversion
TODO: create an end-point for voice conversion
Args:
y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
"""
assert self.num_speakers > 0, "num_speakers has to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
# speaker embedding
if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
else:
raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
@ -525,6 +759,30 @@ class Vits(BaseTTS):
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if self.args.freeze_encoder:
for param in self.text_encoder.parameters():
param.requires_grad = False
if hasattr(self, "emb_l"):
for param in self.emb_l.parameters():
param.requires_grad = False
if self.args.freeze_PE:
for param in self.posterior_encoder.parameters():
param.requires_grad = False
if self.args.freeze_DP:
for param in self.duration_predictor.parameters():
param.requires_grad = False
if self.args.freeze_flow_decoder:
for param in self.flow.parameters():
param.requires_grad = False
if self.args.freeze_waveform_decoder:
for param in self.waveform_decoder.parameters():
param.requires_grad = False
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
@ -532,6 +790,7 @@ class Vits(BaseTTS):
linear_input = batch["linear_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
language_ids = batch["language_ids"]
waveform = batch["waveform"]
# generator pass
@ -540,31 +799,26 @@ class Vits(BaseTTS):
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
waveform.transpose(1, 2),
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
)
# cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"]
wav_seg = segment(
waveform.transpose(1, 2),
outputs["slice_ids"] * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
self.wav_seg_disc_cache = wav_seg
outputs["waveform_seg"] = wav_seg
self.wav_seg_disc_cache = outputs["waveform_seg"]
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
outputs["model_outputs"], wav_seg
outputs["model_outputs"], outputs["waveform_seg"]
)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
waveform_hat=outputs["model_outputs"].float(),
waveform=wav_seg.float(),
waveform=outputs["waveform_seg"].float(),
z_p=outputs["z_p"].float(),
logs_q=outputs["logs_q"].float(),
m_p=outputs["m_p"].float(),
@ -574,6 +828,9 @@ class Vits(BaseTTS):
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
loss_duration=outputs["loss_duration"],
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
gt_spk_emb=outputs["gt_spk_emb"],
syn_spk_emb=outputs["syn_spk_emb"],
)
elif optimizer_idx == 1:
@ -651,32 +908,28 @@ class Vits(BaseTTS):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None
if not self.config.use_d_vector_file
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
"style_wav": None,
}
for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
for idx, s_info in enumerate(test_sentences):
try:
aux_inputs = self.get_aux_input_from_test_sentences(s_info)
wav, alignment, _, _ = synthesis(
self,
aux_inputs["text"],
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
language_id=aux_inputs["language_id"],
language_name=aux_inputs["language_name"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
return test_figures, test_audios
def get_optimizer(self) -> List:
@ -695,8 +948,12 @@ class Vits(BaseTTS):
self.waveform_decoder.parameters(),
)
# add the speaker embedding layer
if hasattr(self, "emb_g"):
if hasattr(self, "emb_g") and self.args.use_speaker_embedding and not self.args.use_d_vector_file:
gen_parameters = chain(gen_parameters, self.emb_g.parameters())
# add the language embedding layer
if hasattr(self, "emb_l") and self.args.use_language_embedding:
gen_parameters = chain(gen_parameters, self.emb_l.parameters())
optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
@ -769,6 +1026,10 @@ class Vits(BaseTTS):
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
# compat band-aid so that pre-trained models do not load the speaker encoder baked into the checkpoint
# TODO: consider baking the speaker encoder into the model and calling it from there,
# as it is probably easier for model distribution.
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
self.load_state_dict(state["model"])
if eval:
self.eval()

122
TTS/tts/utils/languages.py Normal file
View File

@ -0,0 +1,122 @@
import json
import os
from typing import Dict, List
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
class LanguageManager:
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.
Args:
language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
TTS models. Defaults to "".
config (Coqpit, optional): Coqpit config that contains the language information in the datasets field.
Defaults to None.
Examples:
>>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
>>> language_id_mapper = manager.language_id_mapping
"""
language_id_mapping: Dict = {}
def __init__(
self,
language_ids_file_path: str = "",
config: Coqpit = None,
):
self.language_id_mapping = {}
if language_ids_file_path:
self.set_language_ids_from_file(language_ids_file_path)
if config:
self.set_language_ids_from_config(config)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
def num_languages(self) -> int:
return len(list(self.language_id_mapping.keys()))
@property
def language_names(self) -> List:
return list(self.language_id_mapping.keys())
@staticmethod
def parse_language_ids_from_config(c: Coqpit) -> Dict:
"""Set language id from config.
Args:
c (Coqpit): Config
Returns:
Dict: Language ID mapping.
"""
languages = set({})
for dataset in c.datasets:
if "language" in dataset:
languages.add(dataset["language"])
else:
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
return {name: i for i, name in enumerate(sorted(list(languages)))}
def set_language_ids_from_config(self, c: Coqpit) -> None:
"""Set language IDs from config samples.
Args:
items (List): Data sampled returned by `load_meta_data()`.
"""
self.language_id_mapping = self.parse_language_ids_from_config(c)
def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.language_id_mapping = self._load_json(file_path)
def save_language_ids_to_file(self, file_path: str) -> None:
"""Save language IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.language_id_mapping)
def _set_file_path(path):
"""Find the language_ids.json under the given path or the above it.
Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
path_continue = os.path.join(path, "language_ids.json")
fs = fsspec.get_mapper(path).fs
if fs.exists(path_restore):
return path_restore
if fs.exists(path_continue):
return path_continue
return None
def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
weight_language = 1.0 / language_count
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
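A usage sketch for the sampler above, assuming dataset items follow the [text, wav_path, speaker_name, language] layout that `item[3]` implies (all values below are dummies):

from TTS.tts.utils.languages import get_language_weighted_sampler

items = [
    ["hello", "a.wav", "spk1", "en"],
    ["hola", "b.wav", "spk2", "es"],
    ["bonjour", "c.wav", "spk3", "fr"],
    ["hi there", "d.wav", "spk1", "en"],
]
sampler = get_language_weighted_sampler(items)  # WeightedRandomSampler over item indices
print(list(iter(sampler)))                      # e.g. [2, 0, 1, 3]; rarer languages are drawn more often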

View File

@ -7,9 +7,10 @@ import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
@ -161,8 +162,10 @@ class SpeakerManager:
file_path (str): Path to the target json file.
"""
self.d_vectors = self._load_json(file_path)
speakers = sorted({x["name"] for x in self.d_vectors.values()})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
def get_d_vector_by_clip(self, clip_idx: str) -> List:
@ -209,6 +212,32 @@ class SpeakerManager:
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
return d_vectors
def get_random_speaker_id(self) -> Any:
"""Get a random d_vector.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.speaker_ids:
return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
return None
def get_random_d_vector(self) -> Any:
"""Get a random D ID.
Args:
Returns:
np.ndarray: d_vector.
"""
if self.d_vectors:
return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
return None
def get_speakers(self) -> List:
return self.speaker_ids
@ -223,18 +252,15 @@ class SpeakerManager:
config_path (str): Model config file path.
"""
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_model(self.speaker_encoder_config)
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
# normalize the input audio level and trim silences
# self.speaker_encoder_ap.do_sound_norm = True
# self.speaker_encoder_ap.do_trim_silence = True
def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
"""Compute a d_vector from a given audio file.
Args:
wav_file (Union[str, list]): Target file path.
wav_file (Union[str, List[str]]): Target file path or a list of file paths.
Returns:
list: Computed d_vector.
@ -242,12 +268,16 @@ class SpeakerManager:
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform)
spec = torch.from_numpy(spec.T)
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda:
spec = spec.cuda()
spec = spec.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(spec)
m_input = m_input.cuda()
m_input = m_input.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(m_input)
return d_vector
if isinstance(wav_file, list):
@ -364,11 +394,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
if speaker_manager.num_speakers > 0:
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
)
)
)
# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speakers.json")
@ -378,3 +411,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
else:
speaker_manager.save_speaker_ids_to_file(out_file_path)
return speaker_manager
def get_speaker_weighted_sampler(items: list):
speaker_names = np.array([item[2] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
weight_speaker = 1.0 / speaker_count
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
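The speaker-balanced variant follows the same pattern; a sketch with the same assumed item layout (`item[2]` is the speaker name), passed to a DataLoader the way `get_data_loader` does:

from torch.utils.data import DataLoader
from TTS.tts.utils.speakers import get_speaker_weighted_sampler

items = [
    ["hello", "a.wav", "spk1", "en"],
    ["howdy", "b.wav", "spk1", "en"],
    ["hi", "c.wav", "spk2", "en"],
]
sampler = get_speaker_weighted_sampler(items)
loader = DataLoader(items, batch_size=1, sampler=sampler)  # under-represented speakers are over-sampled
for batch in loader:
    print(batch)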

View File

@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
def text_to_seq(text, CONFIG, custom_symbols=None):
def text_to_seq(text, CONFIG, custom_symbols=None, language=None):
text_cleaner = [CONFIG.text_cleaner]
# text or phonemes to sequence vector
if CONFIG.use_phonemes:
@ -23,7 +23,7 @@ def text_to_seq(text, CONFIG, custom_symbols=None):
phoneme_to_sequence(
text,
text_cleaner,
CONFIG.phoneme_language,
language if language else CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
@ -71,6 +71,7 @@ def run_model_torch(
speaker_id: int = None,
style_mel: torch.Tensor = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@ -96,6 +97,7 @@ def run_model_torch(
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
"language_ids": language_id,
},
)
return outputs
@ -160,19 +162,20 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav
def speaker_id_to_torch(speaker_id, cuda=False):
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id)
def id_to_torch(aux_id, cuda=False):
if aux_id is not None:
aux_id = np.asarray(aux_id)
aux_id = torch.from_numpy(aux_id)
if cuda:
return speaker_id.cuda()
return speaker_id
return aux_id.cuda()
return aux_id
def embedding_to_torch(d_vector, cuda=False):
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
d_vector = d_vector.squeeze().unsqueeze(0)
if cuda:
return d_vector.cuda()
return d_vector
@ -208,6 +211,8 @@ def synthesis(
use_griffin_lim=False,
do_trim_silence=False,
d_vector=None,
language_id=None,
language_name=None,
backend="torch",
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
@ -244,6 +249,12 @@ def synthesis(
d_vector (torch.Tensor):
d-vector for multi-speaker models in shape :math:`[1, D]`. Defaults to None.
language_id (int):
Language ID passed to the language embedding layer in a multi-lingual model. Defaults to None.
language_name (str):
Language name corresponding to the language code used by the phonemizer. Defaults to None.
backend (str):
tf or torch. Defaults to "torch".
"""
@ -258,15 +269,18 @@ def synthesis(
if hasattr(model, "make_symbols"):
custom_symbols = model.make_symbols(CONFIG)
# preprocess the given text
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols, language=language_name)
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
@ -278,7 +292,7 @@ def synthesis(
text_inputs = tf.expand_dims(text_inputs, 0)
# synthesize voice
if backend == "torch":
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]

View File

@ -135,3 +135,12 @@ def phoneme_cleaners(text):
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
def multilingual_cleaners(text):
"""Pipeline for multilingual text"""
text = lowercase(text)
text = replace_symbols(text, lang=None)
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text

View File

@ -16,6 +16,60 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
TODO: Merge this with audio.py
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of frames between STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True, pad the audio by (n_fft - hop_length) / 2 samples. Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
use_mel (bool, optional):
If True compute mel spectrograms, otherwise compute linear spectrograms. Defaults to False.
do_amp_to_db (bool, optional):
enable/disable amplitude to dB conversion of the computed spectrograms. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
@ -32,6 +86,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
@ -45,6 +102,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
@ -83,6 +143,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.power is not None:
S = S ** self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
@ -91,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
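For context, the two new filterbank options map directly onto `librosa.filters.mel`; a small standalone comparison (the audio settings are arbitrary):

import librosa

slaney = librosa.filters.mel(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000, htk=False, norm="slaney")
htk = librosa.filters.mel(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000, htk=True, norm=None)
print(slaney.shape, htk.shape)  # (80, 513) (80, 513)
print(slaney.max(), htk.max())  # area-normalized filters peak well below 1.0; un-normalized HTK filters peak at 1.0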
@ -167,7 +237,7 @@ class AudioProcessor(object):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms.. Defaults to None.
maximum filter frequency for computing melspectrograms. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
@ -196,6 +266,12 @@ class AudioProcessor(object):
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
do_rms_norm (bool, optional):
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
db_level (int, optional):
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
@ -233,6 +309,8 @@ class AudioProcessor(object):
do_sound_norm=False,
do_amp_to_db_linear=True,
do_amp_to_db_mel=True,
do_rms_norm=False,
db_level=None,
stats_path=None,
verbose=True,
**_,
@ -264,6 +342,8 @@ class AudioProcessor(object):
self.do_sound_norm = do_sound_norm
self.do_amp_to_db_linear = do_amp_to_db_linear
self.do_amp_to_db_mel = do_amp_to_db_mel
self.do_rms_norm = do_rms_norm
self.db_level = db_level
self.stats_path = stats_path
# setup exp_func for db to amp conversion
if log_func == "np.log":
@ -656,21 +736,6 @@ class AudioProcessor(object):
frame_period=1000 * self.hop_length / self.sample_rate,
)
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
# pad = int((self.win_length / self.hop_length) / 2)
# f0 = [0.0] * pad + f0 + [0.0] * pad
# f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
# f0 = np.array(f0, dtype=np.float32)
# f01, _, _ = librosa.pyin(
# x,
# fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
# fmax=self.mel_fmax,
# frame_length=self.win_length,
# sr=self.sample_rate,
# fill_na=0.0,
# )
# spec = self.melspectrogram(x)
return f0
### Audio Processing ###
@ -713,10 +778,33 @@ class AudioProcessor(object):
"""
return x / abs(x).max() * 0.95
@staticmethod
def _rms_norm(wav, db_level=-27):
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
return wav * a
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
"""Normalize the volume based on RMS of the signal.
Args:
x (np.ndarray): Raw waveform.
db_level (float, optional): Target RMS level in dB. If None, `self.db_level` is used. Defaults to None.
Returns:
np.ndarray: RMS normalized waveform.
"""
if db_level is None:
db_level = self.db_level
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
wav = self._rms_norm(x, db_level)
return wav
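A quick numerical check of the formula above using only NumPy: after scaling, the RMS level of the signal lands at the requested dB level.

import numpy as np

wav = np.random.uniform(-0.3, 0.3, size=16000)  # dummy waveform
db_level = -27
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
wav_norm = wav * a
print(20 * np.log10(np.sqrt(np.mean(wav_norm ** 2))))  # ~ -27.0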
### save and load ###
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly, so it is recommended to resample the files beforehand.
Args:
filename (str): Path to the wav file.
sr (int, optional): Sampling rate for resampling. Defaults to None.
@ -725,8 +813,10 @@ class AudioProcessor(object):
np.ndarray: Loaded waveform.
"""
if self.resample:
# loading with resampling. It is significantly slower.
x, sr = librosa.load(filename, sr=self.sample_rate)
elif sr is None:
# SF is faster than librosa for loading files
x, sr = sf.read(filename)
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
else:
@ -738,6 +828,8 @@ class AudioProcessor(object):
print(f" [!] File cannot be trimmed for silence - {filename}")
if self.do_sound_norm:
x = self.sound_norm(x)
if self.do_rms_norm:
x = self.rms_volume_norm(x, self.db_level)
return x
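A sketch of enabling the new options when loading audio; the wav path is a placeholder and the remaining `AudioProcessor` parameters keep their defaults:

from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(sample_rate=22050, resample=False, do_rms_norm=True, db_level=-27)
wav = ap.load_wav("sample.wav")  # placeholder path; the loaded audio is RMS-normalized to -27 dB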
def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:

View File

@ -26,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self
def copy_model_files(config: Coqpit, out_path, new_fields):
def copy_model_files(config: Coqpit, out_path, new_fields=None):
"""Copy config.json and other model files to training folder and add
new fields.

View File

@ -46,36 +46,66 @@ class ModelManager(object):
with open(file_path, "r", encoding="utf-8") as json_file:
self.models_dict = json.load(json_file)
def list_langs(self):
print(" Name format: type/language")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ")
def _list_models(self, model_type, model_count=0):
model_list = []
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return model_list
def list_datasets(self):
print(" Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}")
def _list_for_model_type(self, model_type):
print(" Name format: language/dataset/model")
models_name_list = []
model_count = 1
model_type = "tts_models"
models_name_list.extend(self._list_models(model_type, model_count))
return [name.replace(model_type + "/", "") for name in models_name_list]
def list_models(self):
print(" Name format: type/language/dataset/model")
models_name_list = []
model_count = 1
for model_type in self.models_dict:
model_list = self._list_models(model_type, model_count)
models_name_list.extend(model_list)
return models_name_list
def list_tts_models(self):
"""Print all `TTS` models and return a list of model names
Format is `language/dataset/model`
"""
return self._list_for_model_type("tts_models")
def list_vocoder_models(self):
"""Print all the `vocoder` models and return a list of model names
Format is `language/dataset/model`
"""
return self._list_for_model_type("vocoder_models")
def list_langs(self):
"""Print all the available languages"""
print(" Name format: type/language")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ")
def list_datasets(self):
"""Print all the datasets"""
print(" Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return models_name_list
print(f" >: {model_type}/{lang}/{dataset}")
def download_model(self, model_name):
"""Download model files given the full model name.
@ -121,6 +151,8 @@ class ModelManager(object):
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)
@ -133,6 +165,12 @@ class ModelManager(object):
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
# update the speaker_encoder file path in the model config.json to the current path
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
@staticmethod
def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download"""
@ -159,8 +197,12 @@ class ModelManager(object):
# download the file
r = requests.get(file_url)
# extract the file
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
z.extractall(output_folder)
try:
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
z.extractall(output_folder)
except zipfile.BadZipFile:
print(f" > Error: Bad zip file - {file_url}")
raise zipfile.BadZipFile
# move the files to the outer path
for file_path in z.namelist()[1:]:
src_path = os.path.join(output_folder, file_path)

View File

@ -1,12 +1,13 @@
import time
from typing import List
from typing import List, Union
import numpy as np
import pysbd
import torch
from TTS.config import load_config
from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
# pylint: disable=unused-wildcard-import
@ -23,6 +24,7 @@ class Synthesizer(object):
tts_checkpoint: str,
tts_config_path: str,
tts_speakers_file: str = "",
tts_languages_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
encoder_checkpoint: str = "",
@ -52,6 +54,7 @@ class Synthesizer(object):
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
self.tts_languages_file = tts_languages_file
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.encoder_checkpoint = encoder_checkpoint
@ -63,6 +66,9 @@ class Synthesizer(object):
self.speaker_manager = None
self.num_speakers = 0
self.tts_speakers = {}
self.language_manager = None
self.num_languages = 0
self.tts_languages = {}
self.d_vector_dim = 0
self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda
@ -110,29 +116,93 @@ class Synthesizer(object):
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
speaker_manager = self._init_speaker_manager()
language_manager = self._init_language_manager()
self._set_speaker_encoder_paths_from_tts_config()
speaker_manager = self._init_speaker_encoder(speaker_manager)
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
if language_manager is not None:
self.tts_model = setup_tts_model(
config=self.tts_config,
speaker_manager=speaker_manager,
language_manager=language_manager,
)
else:
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
def _set_speaker_encoder_paths_from_tts_config(self):
"""Set the encoder paths from the tts model config for models with speaker encoders."""
if hasattr(self.tts_config, "model_args") and hasattr(
self.tts_config.model_args, "speaker_encoder_config_path"
):
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
def _is_use_speaker_embedding(self):
"""Check if the speaker embedding is used in the model"""
# we handle here the case that some models use model_args some don't
use_speaker_embedding = False
if hasattr(self.tts_config, "model_args"):
use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False)
use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False)
return use_speaker_embedding
def _is_use_d_vector_file(self):
"""Check if the d-vector file is used in the model"""
# we handle here the case that some models use model_args some don't
use_d_vector_file = False
if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args
use_d_vector_file = config.get("use_d_vector_file", False)
config = self.tts_config
use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False)
return use_d_vector_file
def _init_speaker_manager(self):
"""Initialize the SpeakerManager"""
# setup if multi-speaker settings are in the global model config
speaker_manager = None
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
if self._is_use_speaker_embedding():
if self.tts_speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
if self.tts_config.get("speakers_file", None):
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
if speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
if self._is_use_d_vector_file():
d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
if self.tts_speakers_file:
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
if self.tts_config.get("d_vector_file", None):
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
if d_vector_file:
speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
return speaker_manager
def _init_speaker_encoder(self, speaker_manager):
"""Initialize the SpeakerEncoder"""
if self.encoder_checkpoint:
if speaker_manager is None:
speaker_manager = SpeakerManager(
encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config
)
else:
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
return speaker_manager
def _init_language_manager(self):
"""Initialize the LanguageManager"""
# setup if multi-lingual settings are in the global model config
language_manager = None
if check_config_and_model_args(self.tts_config, "use_language_embedding", True):
if self.tts_languages_file:
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
elif self.tts_config.get("language_ids_file", None):
language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
else:
language_manager = LanguageManager(config=self.tts_config)
return language_manager
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
"""Load the vocoder model.
@ -174,13 +244,21 @@ class Synthesizer(object):
wav = np.array(wav)
self.ap.save_wav(wav, path, self.output_sample_rate)
def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
def tts(
self,
text: str,
speaker_name: str = "",
language_name: str = "",
speaker_wav: Union[str, List[str]] = None,
style_wav=None,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Args:
text (str): input text.
speaker_idx (str, optional): spekaer id for multi-speaker models. Defaults to "".
speaker_wav ():
speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
Returns:
@ -196,29 +274,49 @@ class Synthesizer(object):
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
if speaker_idx and isinstance(speaker_idx, str):
if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_name)[0]
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
elif not speaker_idx and not speaker_wav:
elif not speaker_name and not speaker_wav:
raise ValueError(
" [!] Look like you use a multi-speaker model. "
"You need to define either a `speaker_idx` or a `style_wav` to use a multi-speaker model."
"You need to define either a `speaker_name` or a `style_wav` to use a multi-speaker model."
)
else:
speaker_embedding = None
else:
if speaker_idx:
if speaker_name:
raise ValueError(
f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}."
f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
)
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
):
if language_name and isinstance(language_name, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
elif not language_name:
raise ValueError(
" [!] Look like you use a multi-lingual model. "
"You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
)
else:
raise ValueError(
f" [!] Missing language_ids.json file path for selecting language {language_name}."
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
)
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
@ -234,6 +332,8 @@ class Synthesizer(object):
use_cuda=self.use_cuda,
ap=self.ap,
speaker_id=speaker_id,
language_id=language_id,
language_name=language_name,
style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,
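For reference, a minimal usage sketch of the updated `tts()` signature, assuming `synthesizer` is an already-initialized Synthesizer wrapping a multi-speaker, multi-lingual model; the speaker and language names below are illustrative placeholders:
# Hypothetical call of the new tts() interface; names are placeholders, not real IDs.
wav = synthesizer.tts(
    "This is an example.",
    speaker_name="mary_ann",
    language_name="en",
)
synthesizer.save_wav(wav, "example_out.wav")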

144
TTS/utils/vad.py Normal file
View File

@ -0,0 +1,144 @@
# This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import collections
import contextlib
import wave
import webrtcvad
def read_wave(path):
"""Reads a .wav file.
Takes the path, and returns (PCM audio data, sample rate).
"""
with contextlib.closing(wave.open(path, "rb")) as wf:
num_channels = wf.getnchannels()
assert num_channels == 1
sample_width = wf.getsampwidth()
assert sample_width == 2
sample_rate = wf.getframerate()
assert sample_rate in (8000, 16000, 32000, 48000)
pcm_data = wf.readframes(wf.getnframes())
return pcm_data, sample_rate
def write_wave(path, audio, sample_rate):
"""Writes a .wav file.
Takes path, PCM audio data, and sample rate.
"""
with contextlib.closing(wave.open(path, "wb")) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio)
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, _bytes, timestamp, duration):
self.bytes = _bytes
self.timestamp = timestamp
self.duration = duration
def frame_generator(frame_duration_ms, audio, sample_rate):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate.
Yields Frames of the requested duration.
"""
n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
offset = 0
timestamp = 0.0
duration = (float(n) / sample_rate) / 2.0
while offset + n < len(audio):
yield Frame(audio[offset : offset + n], timestamp, duration)
timestamp += duration
offset += n
def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
"""Filters out non-voiced audio frames.
Given a webrtcvad.Vad and a source of audio frames, yields only
the voiced audio.
Uses a padded, sliding window algorithm over the audio frames.
When more than 90% of the frames in the window are voiced (as
reported by the VAD), the collector triggers and begins yielding
audio frames. Then the collector waits until 90% of the frames in
the window are unvoiced to detrigger.
The window is padded at the front and back to provide a small
amount of silence or the beginnings/endings of speech around the
voiced frames.
Arguments:
sample_rate - The audio sample rate, in Hz.
frame_duration_ms - The frame duration in milliseconds.
padding_duration_ms - The amount to pad the window, in milliseconds.
vad - An instance of webrtcvad.Vad.
frames - a source of audio frames (sequence or generator).
Returns: A generator that yields PCM audio data.
"""
num_padding_frames = int(padding_duration_ms / frame_duration_ms)
# We use a deque for our sliding window/ring buffer.
ring_buffer = collections.deque(maxlen=num_padding_frames)
# We have two states: TRIGGERED and NOTTRIGGERED. We start in the
# NOTTRIGGERED state.
triggered = False
voiced_frames = []
for frame in frames:
is_speech = vad.is_speech(frame.bytes, sample_rate)
# sys.stdout.write('1' if is_speech else '0')
if not triggered:
ring_buffer.append((frame, is_speech))
num_voiced = len([f for f, speech in ring_buffer if speech])
# If we're NOTTRIGGERED and more than 90% of the frames in
# the ring buffer are voiced frames, then enter the
# TRIGGERED state.
if num_voiced > 0.9 * ring_buffer.maxlen:
triggered = True
# sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,))
# We want to yield all the audio we see from now until
# we are NOTTRIGGERED, but we have to start with the
# audio that's already in the ring buffer.
for f, _ in ring_buffer:
voiced_frames.append(f)
ring_buffer.clear()
else:
# We're in the TRIGGERED state, so collect the audio data
# and add it to the ring buffer.
voiced_frames.append(frame)
ring_buffer.append((frame, is_speech))
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
# If more than 90% of the frames in the ring buffer are
# unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected.
if num_unvoiced > 0.9 * ring_buffer.maxlen:
# sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
triggered = False
yield b"".join([f.bytes for f in voiced_frames])
ring_buffer.clear()
voiced_frames = []
# If we have any leftover voiced audio when we run out of input,
# yield it.
if voiced_frames:
yield b"".join([f.bytes for f in voiced_frames])
def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
vad = webrtcvad.Vad(int(aggressiveness))
frames = list(frame_generator(30, audio, sample_rate))
segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames)
return segments
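For reference, a minimal end-to-end sketch of how these helpers fit together to strip silence from a clip; the file paths are illustrative, and the input wav must satisfy the mono/16-bit/sample-rate assertions in `read_wave`:
# Hypothetical usage of the VAD helpers above; paths are placeholders.
audio, sample_rate = read_wave("speaker_16khz.wav")
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2)
write_wave("speaker_16khz_voiced.wav", b"".join(segments), sample_rate)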

View File

@ -3,10 +3,15 @@
VITS (Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
) is an End-to-End (encoder -> vocoder together) TTS model that takes advantage of SOTA DL techniques like GANs, VAE,
Normalizing Flows. It does not require external alignment annotations and learns the text-to-audio alignment
using MAS as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
using MAS, as explained in the paper. The model architecture is a combination of GlowTTS encoder and HiFiGAN vocoder.
It is a feed-forward model with a 67.12x real-time factor on a GPU.
🐸 YourTTS is a multi-speaker and multi-lingual TTS model that can perform voice conversion and zero-shot speaker adaptation.
It can also learn a new language or voice from roughly one minute of audio, which opens the door to training
TTS models for low-resource languages. 🐸 YourTTS uses VITS as the backbone architecture, coupled with a speaker encoder model.
## Important resources & papers
- 🐸 YourTTS: https://arxiv.org/abs/2112.02418
- VITS: https://arxiv.org/pdf/2106.06103.pdf
- Neural Spline Flows: https://arxiv.org/abs/1906.04032
- Variational Autoencoder: https://arxiv.org/pdf/1312.6114.pdf

View File

@ -180,7 +180,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
plt.figure()
plt.rcParams["figure.figsize"] = (50, 20)
barplot = sns.barplot(x, y)
barplot = sns.barplot(x=x, y=y)
if save_path:
fig = barplot.get_figure()
fig.savefig(os.path.join(save_path, "phoneme_dist"))

View File

@ -0,0 +1,130 @@
import os
from glob import glob
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [
BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
for path in dataset_paths
]
audio_config = BaseAudioConfig(
sample_rate=16000,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=False,
trim_db=23.0,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=True,
do_amp_to_db_linear=False,
resample=False,
)
vitsArgs = VitsArgs(
use_language_embedding=True,
embedded_language_dim=4,
use_speaker_embedding=True,
use_sdp=False,
)
config = VitsConfig(
model_args=vitsArgs,
audio=audio_config,
run_name="vits_vctk",
use_speaker_embedding=True,
batch_size=32,
eval_batch_size=16,
batch_group_size=0,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="multilingual_cleaners",
use_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
min_seq_len=32 * 256 * 4,
max_seq_len=160000,
output_path=output_path,
datasets=dataset_config,
characters={
"pad": "_",
"eos": "&",
"bos": "*",
"characters": "'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
"punctuations": "'(),-.:;¿? ",
"phonemes": None,
"unique": True,
},
test_sentences=[
[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"mary_ann",
None,
"en_US",
],
[
"Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
"ezwa",
None,
"fr_FR",
],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
],
)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker names to ids used by the model and the data loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config)
config.model_args.num_languages = language_manager.num_languages
# init model
model = Vits(config, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()

View File

@ -26,3 +26,5 @@ unidic-lite==1.0.8
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
fsspec>=2021.04.0
pyworld
webrtcvad
torchaudio

View File

@ -38,3 +38,14 @@ def run_cli(command):
def get_test_data_config():
return BaseDatasetConfig(name="ljspeech", path="tests/data/ljspeech/", meta_file_train="metadata.csv")
def assertHasAttr(test_obj, obj, intendedAttr):
# from https://stackoverflow.com/questions/48078636/pythons-unittest-lacks-an-asserthasattr-method-what-should-i-use-instead
testBool = hasattr(obj, intendedAttr)
test_obj.assertTrue(testBool, msg=f"obj lacking an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
def assertHasNotAttr(test_obj, obj, intendedAttr):
testBool = hasattr(obj, intendedAttr)
test_obj.assertFalse(testBool, msg=f"obj should not have an attribute. obj: {obj}, intendedAttr: {intendedAttr}")
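For clarity, a minimal self-contained sketch of how these helpers are intended to be called from a test case; the dummy class below is purely illustrative:
# Hypothetical usage of the assertion helpers above.
import unittest

class _Dummy:
    # illustrative object exposing a single attribute
    emb_g = None

class AttrHelperExampleTest(unittest.TestCase):
    def test_attr_helpers(self):
        obj = _Dummy()
        assertHasAttr(self, obj, "emb_g")
        assertHasNotAttr(self, obj, "emb_l")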

View File

@ -0,0 +1,80 @@
import os
import unittest
import torch
from tests import get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
torch.manual_seed(1)
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
dataset_config_en = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="en",
)
dataset_config_pt = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="pt-br",
)
# pylint: disable=protected-access
class TestFindUniquePhonemes(unittest.TestCase):
@staticmethod
def test_espeak_phonemes():
# prepare the config
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
datasets=[dataset_config_en, dataset_config_pt],
)
config.save_json(config_path)
# run test
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
@staticmethod
def test_no_espeak_phonemes():
# prepare the config
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
datasets=[dataset_config_en, dataset_config_pt],
)
config.save_json(config_path)
# run test
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')

View File

@ -0,0 +1,29 @@
import os
import unittest
import torch
from tests import get_tests_input_path, get_tests_output_path, run_cli
torch.manual_seed(1)
# pylint: disable=protected-access
class TestRemoveSilenceVAD(unittest.TestCase):
@staticmethod
def test():
# set paths
wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")
# resample audios
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
)
# run test
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
)
run_cli(f'rm -rf "{output_resample_path}"')
run_cli(f'rm -rf "{output_path}"')

View File

@ -13,7 +13,7 @@ file_path = get_tests_input_path()
class LSTMSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
dummy_input = T.rand(4, 20, 80) # B x T x D
dummy_input = T.rand(4, 80, 20) # B x D x T
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
# computing d vectors
@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
assert output.type() == "torch.FloatTensor"
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch
dummy_input = T.rand(1, 240, 80) # B x T x D
dummy_input = T.rand(1, 80, 240) # B x D x T
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
assert output.shape[0] == 1
assert output.shape[1] == 256
@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
class ResNetSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
dummy_input = T.rand(4, 20, 80) # B x T x D
dummy_input = T.rand(4, 80, 20) # B x D x T
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
# computing d vectors
@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
assert output.type() == "torch.FloatTensor"
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch
dummy_input = T.rand(1, 240, 80) # B x T x D
dummy_input = T.rand(1, 80, 240) # B x D x T
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
assert output.shape[0] == 1
assert output.shape[1] == 256

View File

@ -6,7 +6,7 @@ import torch
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.speaker_encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
config.audio.resample = True
# create a dummy speaker encoder
model = setup_model(config)
model = setup_speaker_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
# load audio processor and speaker encoder
@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
# load a sample audio and compute embedding
waveform = ap.load_wav(sample_wav_path)
mel = ap.melspectrogram(waveform)
d_vector = manager.compute_d_vector(mel.T)
d_vector = manager.compute_d_vector(mel)
assert d_vector.shape[1] == 256
# compute d_vector directly from an input file

View File

@ -38,6 +38,11 @@ class TestTTSDataset(unittest.TestCase):
def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path, "metadata.csv")
# add a default language because the TTSDataset now expects a language
language = ""
items = [[*item, language] for item in items]
dataset = TTSDataset(
r,
c.text_cleaner,

View File

@ -0,0 +1,58 @@
import functools
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler
# Fixing random state to avoid random fails
torch.manual_seed(0)
dataset_config_en = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="en",
)
dataset_config_pt = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="pt-br",
)
# Adding the EN samples twice to create an unbalanced dataset
train_samples, eval_samples = load_tts_samples(
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)
def is_balanced(lang_1, lang_2):
return 0.85 < lang_1 / lang_2 < 1.2
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index][3] == "en":
en += 1
else:
pt += 1
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index][3] == "en":
en += 1
else:
pt += 1
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
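For completeness, a minimal sketch of how such a weighted sampler would typically be plugged into a DataLoader; the batch size is an illustrative placeholder and the real trainer wires this up internally:
# Hypothetical sketch: feeding the language-weighted sampler to a DataLoader.
loader = torch.utils.data.DataLoader(
    train_samples,
    batch_size=8,
    sampler=get_language_weighted_sampler(train_samples),
)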

View File

@ -0,0 +1,5 @@
{
"en": 0,
"fr-fr": 1,
"pt-br": 2
}

View File

@ -0,0 +1,240 @@
import os
import unittest
import torch
from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# pylint: disable=no-self-use
class TestVits(unittest.TestCase):
def test_init_multispeaker(self):
num_speakers = 10
args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
model = Vits(args)
assertHasAttr(self, model, "emb_g")
args = VitsArgs(num_speakers=0, use_speaker_embedding=True)
model = Vits(args)
assertHasNotAttr(self, model, "emb_g")
args = VitsArgs(num_speakers=10, use_speaker_embedding=False)
model = Vits(args)
assertHasNotAttr(self, model, "emb_g")
args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
model = Vits(args)
self.assertEqual(model.embedded_speaker_dim, 101)
def test_init_multilingual(self):
args = VitsArgs(language_ids_file=None, use_language_embedding=False)
model = Vits(args)
self.assertEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, 0)
self.assertEqual(model.emb_l, None)
args = VitsArgs(language_ids_file=LANG_FILE)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, 0)
self.assertEqual(model.emb_l, None)
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
self.assertNotEqual(model.emb_l, None)
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
self.assertNotEqual(model.emb_l, None)
def test_get_aux_input(self):
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
args = VitsArgs()
model = Vits(args)
aux_out = model.get_aux_input(aux_input)
speaker_id = torch.randint(10, (1,))
language_id = torch.randint(10, (1,))
d_vector = torch.rand(1, 128)
aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
aux_out = model.get_aux_input(aux_input)
self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)
def test_voice_conversion(self):
num_speakers = 10
spec_len = 101
spec_effective_len = 50
args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
model = Vits(args)
ref_inp = torch.randn(1, spec_len, 513)
ref_inp_len = torch.randint(1, spec_effective_len, (1,))
ref_spk_id = torch.randint(1, num_speakers, (1,))
tgt_spk_id = torch.randint(1, num_speakers, (1,))
o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)
self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
self.assertEqual(y_mask.shape, (1, 1, spec_len))
self.assertEqual(y_mask.sum(), ref_inp_len[0])
self.assertEqual(z.shape, (1, args.hidden_channels, spec_len))
self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len))
self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len))
def _init_inputs(self, config):
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device)
spec_lengths = torch.randint(20, 30, (8,)).long().to(device)
spec_lengths[-1] = spec.size(2)
waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device)
return input_dummy, input_lengths, spec, spec_lengths, waveform
def _check_forward_outputs(self, config, output_dict, encoder_config=None):
self.assertEqual(
output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
)
self.assertEqual(output_dict["alignments"].shape, (8, 128, 30))
self.assertEqual(output_dict["alignments"].max(), 1)
self.assertEqual(output_dict["alignments"].min(), 0)
self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30))
self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30))
self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30))
self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30))
self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30))
self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30))
self.assertEqual(
output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
)
if encoder_config:
self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
else:
self.assertEqual(output_dict["gt_spk_emb"], None)
self.assertEqual(output_dict["syn_spk_emb"], None)
def test_forward(self):
num_speakers = 0
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
model = Vits(config).to(device)
output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
self._check_forward_outputs(config, output_dict)
def test_multispeaker_forward(self):
num_speakers = 10
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
model = Vits(config).to(device)
output_dict = model.forward(
input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
)
self._check_forward_outputs(config, output_dict)
def test_multilingual_forward(self):
num_speakers = 10
num_langs = 3
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
model = Vits(config).to(device)
output_dict = model.forward(
input_dummy,
input_lengths,
spec,
spec_lengths,
waveform,
aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
)
self._check_forward_outputs(config, output_dict)
def test_secl_forward(self):
num_speakers = 10
num_langs = 3
speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
speaker_encoder_config.model_params["use_torch_spec"] = True
speaker_encoder = setup_speaker_encoder_model(speaker_encoder_config).to(device)
speaker_manager = SpeakerManager()
speaker_manager.speaker_encoder = speaker_encoder
args = VitsArgs(
language_ids_file=LANG_FILE,
use_language_embedding=True,
spec_segment_size=10,
use_speaker_encoder_as_loss=True,
)
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
config.audio.sample_rate = 16000
input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
model = Vits(config, speaker_manager=speaker_manager).to(device)
output_dict = model.forward(
input_dummy,
input_lengths,
spec,
spec_lengths,
waveform,
aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
)
self._check_forward_outputs(config, output_dict, speaker_encoder_config)
def test_inference(self):
num_speakers = 0
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
model = Vits(config).to(device)
_ = model.inference(input_dummy)
def test_multispeaker_inference(self):
num_speakers = 10
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
model = Vits(config).to(device)
_ = model.inference(input_dummy, {"speaker_ids": speaker_ids})
def test_multilingual_inference(self):
num_speakers = 10
num_langs = 3
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
model = Vits(config).to(device)
_ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})

View File

@ -0,0 +1,62 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-0"],
],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -0,0 +1,91 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
dataset_config_en = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="en",
)
dataset_config_pt = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="pt-br",
)
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=False,
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-0", None, "en"],
["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
],
datasets=[dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multilingual mode
config.model_args.use_language_embedding = True
config.use_language_embedding = True
# deactivate multispeaker mode
config.model_args.use_speaker_embedding = False
config.use_speaker_embedding = False
# active multispeaker d-vec mode
config.model_args.use_d_vector_file = True
config.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256
config.d_vector_dim = 256
# duration predictor
config.model_args.use_sdp = True
config.use_sdp = True
# deactivate language sampler
config.use_language_weighted_sampler = False
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -0,0 +1,88 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
dataset_config_en = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="en",
)
dataset_config_pt = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="pt-br",
)
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech", None, "en"],
["Be a voice, not an echo.", "ljspeech", None, "pt-br"],
],
datasets=[dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multilingual mode
config.model_args.use_language_embedding = True
config.use_language_embedding = True
# active multispeaker mode
config.model_args.use_speaker_embedding = True
config.use_speaker_embedding = True
# deactivate multispeaker d-vec mode
config.model_args.use_d_vector_file = False
config.use_d_vector_file = False
# duration predictor
config.model_args.use_sdp = False
config.use_sdp = False
# active language sampler
config.use_language_weighted_sampler = True
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -0,0 +1,63 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
use_espeak_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech"],
],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_dim = 256
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -25,7 +25,7 @@ config = VitsConfig(
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
["Be a voice, not an echo."],
],
)
config.audio.do_trim_silence = True

View File

@ -4,6 +4,7 @@ import os
import shutil
from tests import get_tests_output_path, run_cli
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager
@ -17,21 +18,30 @@ def test_run_all_models():
manager = ModelManager(output_prefix=get_tests_output_path())
model_names = manager.list_models()
for model_name in model_names:
print(f"\n > Run - {model_name}")
model_path, _, _ = manager.download_model(model_name)
if "tts_models" in model_name:
local_download_dir = os.path.dirname(model_path)
# download and run the model
speaker_files = glob.glob(local_download_dir + "/speaker*")
language_files = glob.glob(local_download_dir + "/language*")
language_id = ""
if len(speaker_files) > 0:
# multi-speaker model
if "speaker_ids" in speaker_files[0]:
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
elif "speakers" in speaker_files[0]:
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
# multi-lingual model - Assuming multi-lingual models are also multi-speaker
if len(language_files) > 0 and "language_ids" in language_files[0]:
language_manager = LanguageManager(language_ids_file_path=language_files[0])
language_id = language_manager.language_names[0]
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
run_cli(
f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
)
else:
# single-speaker model