Support the use of speaker embedding as emotion embedding

Edresson Casanova 2022-06-07 13:29:22 -03:00
parent 360b969c23
commit 4b59f07946
13 changed files with 244 additions and 79 deletions
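
In short: when use_speaker_embedding_as_emotion is enabled, the speaker conditioning vector g is reused as the emotion embedding eg (optionally passed through the emotion embedding squeezer) before it conditions the text encoder, duration predictor, and noise scale predictor. A minimal standalone sketch of that idea follows; the module is illustrative only (it is not the actual Vits class), but the tensor shapes and the squeezer call mirror the diff below.

import torch
import torch.nn.functional as F
from torch import nn


class SpeakerAsEmotion(nn.Module):
    """Sketch: reuse the speaker d-vector as the emotion conditioning."""

    def __init__(self, d_vector_dim=256, emotion_embedding_dim=256, use_squeezer=False):
        super().__init__()
        self.use_squeezer = use_squeezer
        if use_squeezer:
            # mirrors emotion_embedding_squeezer in vits.py: a Linear layer from the
            # speaker embedding size down to the emotion embedding size
            self.emotion_embedding_squeezer = nn.Linear(d_vector_dim, emotion_embedding_dim)

    def forward(self, g):
        # g: [B, d_vector_dim, 1] speaker conditioning (d-vector or speaker embedding)
        eg = g  # the speaker embedding doubles as the emotion embedding
        if self.use_squeezer:
            eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)
        return eg


g = torch.randn(2, 256, 1)                   # batch of speaker d-vectors
eg = SpeakerAsEmotion(use_squeezer=True)(g)  # emotion conditioning derived from g
print(eg.shape)                              # torch.Size([2, 256, 1])
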

View File

@@ -7,8 +7,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.managers import save_file
from TTS.tts.utils.managers import EmbeddingManager
from TTS.tts.utils.managers import EmbeddingManager, save_file
parser = argparse.ArgumentParser(
description="""Compute embedding vectors for each wav file in a dataset.\n\n"""

View File

@@ -179,7 +179,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
default=None,
)
parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None)
parser.add_argument(
"--style_speaker_name",
type=str,
help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.",
default=None,
)
parser.add_argument(
"--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
)
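
As a point of reference, the new flag accompanies the existing --style_wav argument of the tts CLI; a hypothetical invocation against a trained prosody-encoder checkpoint might look like the following (all paths and the speaker name are placeholders, and the exact flag set depends on the model config):

tts --text "Be a voice, not an echo." \
    --model_path /path/to/checkpoint.pth --config_path /path/to/config.json \
    --speakers_file_path /path/to/speakers.json --speaker_idx ljspeech-1 \
    --style_wav /path/to/reference.wav --style_speaker_name ljspeech-1 \
    --out_path output.wav
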

View File

@@ -32,7 +32,9 @@ class ReversalClassifier(nn.Module):
reversal (bool): If True, reverse the gradients. Default: True
"""
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
def __init__(
self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True
):
super().__init__()
self.reversal = reversal
self._lambda = scale_factor

View File

@@ -9,13 +9,8 @@ from torch.nn import functional
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import ssim
from TTS.utils.audio import TorchSTFT
from TTS.vocoder.layers.losses import (
MelganFeatureLoss,
MSEDLoss,
MSEGLoss,
_apply_D_loss,
_apply_G_adv_loss,
)
from TTS.vocoder.layers.losses import MelganFeatureLoss, MSEDLoss, MSEGLoss, _apply_D_loss, _apply_G_adv_loss
# pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305
@@ -730,7 +725,6 @@ class VitsGeneratorLoss(nn.Module):
loss += loss_prosody_enc_emo_classifier
return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
if loss_text_enc_spk_rev_classifier is not None:
loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
loss += loss_text_enc_spk_rev_classifier
@@ -779,7 +773,6 @@ class VitsDiscriminatorLoss(nn.Module):
self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
self.disc_latent_gan_loss = MSEDLoss()
@staticmethod
def discriminator_loss(scores_real, scores_fake):
loss = 0

View File

@@ -109,7 +109,9 @@ class LatentDiscriminator(nn.Module):
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
self.discriminators = nn.ModuleList(
[
norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(
nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))
),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),

View File

@@ -1,5 +1,6 @@
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
from TTS.tts.layers.tacotron.gst_layers import GST
class VitsGST(GST):
def __init__(self, *args, **kwargs):
@@ -9,6 +10,7 @@ class VitsGST(GST):
style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
return style_embed, None
class VitsVAE(CapacitronVAE):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -19,12 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.emotions import EmotionManager
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
@@ -37,7 +36,6 @@ from TTS.tts.utils.visual import plot_alignment
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
##############################
# IO / Feature extraction
##############################
@@ -684,7 +682,11 @@ class Vits(BaseTTS):
dp_cond_embedding_dim += self.args.prosody_embedding_dim
dp_extra_inp_dim = 0
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings or self.args.use_speaker_embedding_as_emotion) and not self.args.use_noise_scale_predictor:
if (
self.args.use_emotion_embedding
or self.args.use_external_emotions_embeddings
or self.args.use_speaker_embedding_as_emotion
) and not self.args.use_noise_scale_predictor:
dp_extra_inp_dim += self.args.emotion_embedding_dim
if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
@@ -711,22 +713,22 @@
)
if self.args.use_prosody_encoder:
if self.args.prosody_encoder_type == 'gst':
if self.args.prosody_encoder_type == "gst":
self.prosody_encoder = VitsGST(
num_mel=self.args.hidden_channels,
num_heads=self.args.prosody_encoder_num_heads,
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
)
elif self.args.prosody_encoder_type == 'vae':
elif self.args.prosody_encoder_type == "vae":
self.prosody_encoder = VitsVAE(
num_mel=self.args.hidden_channels,
capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
)
else:
raise RuntimeError(
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
)
f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
)
if self.args.use_prosody_enc_spk_reversal_classifier:
self.speaker_reversal_classifier = ReversalClassifier(
in_channels=self.args.prosody_embedding_dim,
@@ -738,12 +740,16 @@
in_channels=self.args.prosody_embedding_dim,
out_channels=self.num_emotions,
hidden_channels=256,
reversal=False
reversal=False,
)
if self.args.use_noise_scale_predictor:
noise_scale_predictor_input_dim = self.args.hidden_channels
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
if (
self.args.use_emotion_embedding
or self.args.use_external_emotions_embeddings
or self.args.use_speaker_embedding_as_emotion
):
noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
if self.args.use_prosody_encoder:
@@ -763,15 +769,18 @@
)
if self.args.use_emotion_embedding_squeezer:
self.emotion_embedding_squeezer = nn.Linear(in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim)
self.emotion_embedding_squeezer = nn.Linear(
in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim
)
if self.args.use_speaker_embedding_squeezer:
self.speaker_embedding_squeezer = nn.Linear(in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim)
self.speaker_embedding_squeezer = nn.Linear(
in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim
)
if self.args.use_text_enc_spk_reversal_classifier:
self.speaker_text_enc_reversal_classifier = ReversalClassifier(
in_channels=self.args.hidden_channels
+ dp_extra_inp_dim,
in_channels=self.args.hidden_channels + dp_extra_inp_dim,
out_channels=self.num_speakers,
hidden_channels=256,
)
@@ -781,7 +790,7 @@
in_channels=self.args.hidden_channels,
out_channels=self.num_emotions,
hidden_channels=256,
reversal=False
reversal=False,
)
self.waveform_decoder = HifiganGenerator(
@@ -1176,9 +1185,16 @@
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
if self.args.use_speaker_embedding_as_emotion:
eg = g
# squeezers
if self.args.use_emotion_embedding_squeezer:
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
if (
self.args.use_emotion_embedding
or self.args.use_external_emotions_embeddings
or self.args.use_speaker_embedding_as_emotion
):
eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)
if self.args.use_speaker_embedding_squeezer:
@@ -1200,7 +1216,7 @@
prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
pros_emb, vae_outputs = self.prosody_encoder(
prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
y_lengths
y_lengths,
)
pros_emb = pros_emb.transpose(1, 2)
@@ -1215,7 +1231,7 @@
x_lengths,
lang_emb=lang_emb,
emo_emb=eg if not self.args.use_noise_scale_predictor else None,
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
)
# reversal speaker loss to force the encoder to be speaker identity free
@@ -1250,10 +1266,14 @@
if self.args.use_noise_scale_predictor:
nsp_input = torch.transpose(m_p_expanded, 1, -1)
if self.args.use_prosody_encoder and pros_emb is not None:
nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
nsp_input = torch.cat(
(nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
)
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
nsp_input = torch.cat(
(nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
)
nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
@@ -1314,7 +1334,7 @@
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
"loss_prosody_enc_emo_classifier": l_pros_emotion,
"loss_text_enc_spk_rev_classifier": l_text_speaker,
"loss_text_enc_emo_classifier": l_text_emotion
"loss_text_enc_emo_classifier": l_text_emotion,
}
)
return outputs
@@ -1373,9 +1393,16 @@
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
if self.args.use_speaker_embedding_as_emotion:
eg = g
# squeezers
if self.args.use_emotion_embedding_squeezer:
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
if (
self.args.use_emotion_embedding
or self.args.use_external_emotions_embeddings
or self.args.use_speaker_embedding_as_emotion
):
eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)
if self.args.use_speaker_embedding_squeezer:
@@ -1399,13 +1426,12 @@
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
pros_emb = pros_emb.transpose(1, 2)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
x_lengths,
lang_emb=lang_emb,
emo_emb=eg if not self.args.use_noise_scale_predictor else None,
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None
pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
)
# duration predictor
@@ -1448,10 +1474,14 @@
if self.args.use_noise_scale_predictor:
nsp_input = torch.transpose(m_p, 1, -1)
if self.args.use_prosody_encoder and pros_emb is not None:
nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
nsp_input = torch.cat(
(nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
)
if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
nsp_input = torch.cat(
(nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
)
nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
@@ -1521,7 +1551,6 @@
wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
return wav
def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
"""Forward pass for voice conversion
@@ -1550,7 +1579,6 @@
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
@@ -1599,18 +1627,16 @@
self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init
# compute scores and features
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _= self.disc(
outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
outputs["model_outputs"].detach(),
outputs["waveform_seg"],
outputs["m_p"].detach(),
outputs["z_p"].detach(),
)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
scores_disc_real,
scores_disc_fake,
scores_disc_zp,
scores_disc_mp
)
loss_dict = criterion[optimizer_idx](scores_disc_real, scores_disc_fake, scores_disc_zp, scores_disc_mp)
return outputs, loss_dict
if optimizer_idx == 1:
@@ -1640,8 +1666,20 @@
)
# compute discriminator scores and features
scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc(
self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
(
scores_disc_fake,
feats_disc_fake,
_,
feats_disc_real,
scores_disc_mp,
feats_disc_mp,
_,
feats_disc_zp,
) = self.disc(
self.model_outputs_cache["model_outputs"],
self.model_outputs_cache["waveform_seg"],
self.model_outputs_cache["m_p"],
self.model_outputs_cache["z_p"].detach(),
)
# compute losses
@@ -1669,7 +1707,7 @@
loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
scores_disc_mp=scores_disc_mp,
feats_disc_mp=feats_disc_mp,
feats_disc_zp=feats_disc_zp
feats_disc_zp=feats_disc_zp,
)
return self.model_outputs_cache, loss_dict
@@ -1729,7 +1767,14 @@
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = (
None,
None,
None,
None,
None,
None,
)
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
@@ -1748,12 +1793,18 @@
text = sentence_info
if style_wav and style_speaker_name is None:
raise RuntimeError(
" [!] You must to provide the style_speaker_name for the style_wav !!"
)
raise RuntimeError(" [!] You must to provide the style_speaker_name for the style_wav !!")
# get speaker id/d_vector
speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = (
None,
None,
None,
None,
None,
None,
None,
)
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
@@ -1762,7 +1813,9 @@
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
if style_wav is not None:
style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
style_speaker_name, num_samples=None, randomize=False
)
elif config.use_speaker_embedding:
if speaker_name is None:
@@ -1893,7 +1946,15 @@
emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
emotion_embeddings = torch.FloatTensor(emotion_embeddings)
if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
if (
self.emotion_manager is not None
and self.emotion_manager.embeddings
and (
self.args.use_emotion_embedding
or self.args.use_prosody_enc_emo_classifier
or self.args.use_text_enc_emo_classifier
)
):
emotion_mapping = self.emotion_manager.embeddings
emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]

View File

@@ -94,7 +94,11 @@ class EmotionManager(EmbeddingManager):
EmotionEncoder: Emotion encoder object.
"""
emotion_manager = None
if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
if (
get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False)
or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False)
or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False)
):
if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
emotion_manager = EmotionManager(
emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +110,11 @@
)
)
if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
if (
get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False)
or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False)
or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False)
):
if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(

View File

@@ -217,7 +217,6 @@ def synthesis(
if style_speaker_d_vector is not None:
style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
if not isinstance(style_feature, dict):
# GST or Capacitron style mel
style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)

View File

@@ -306,9 +306,17 @@ class Synthesizer(object):
# handle emotion
emotion_embedding, emotion_id = None, None
if not reference_wav and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
)):
if (
not reference_wav
and not getattr(self.tts_model, "prosody_encoder", False)
and (
self.tts_emotions_file
or (
getattr(self.tts_model, "emotion_manager", None)
and getattr(self.tts_model.emotion_manager, "ids", None)
)
)
):
if emotion_name and isinstance(emotion_name, str):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
getattr(self.tts_config, "model_args", None)
@@ -426,7 +434,7 @@
d_vector=speaker_embedding,
use_griffin_lim=use_gl,
reference_speaker_id=reference_speaker_id,
reference_d_vector=reference_speaker_embedding
reference_d_vector=reference_speaker_embedding,
)
waveform = outputs
if not use_gl:

View File

@@ -58,6 +58,9 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
if not os.path.isfile(continue_speakers_path):
continue_speakers_path = continue_speakers_path.replace(".json", ".pth")
# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:

View File

@@ -37,23 +37,16 @@ config.audio.trim_db = 60
config.model_args.use_speaker_embedding = False
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 100
config.model_args.speaker_embedding_channels = 256
config.model_args.d_vector_dim = 256
# emotion
config.model_args.use_external_emotions_embeddings = True
config.model_args.use_emotion_embedding = False
config.model_args.emotion_embedding_dim = 64
config.model_args.emotion_embedding_dim = 256
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.model_args.use_text_enc_spk_reversal_classifier = False
config.model_args.use_emotion_embedding_squeezer = True
config.model_args.emotion_embedding_squeezer_input_dim = 256
config.model_args.use_speaker_embedding_squeezer = True
config.model_args.speaker_embedding_squeezer_input_dim = 256
# consistency loss
# config.model_args.use_emotion_encoder_as_loss = True
# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"

View File

@@ -0,0 +1,90 @@
import glob
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
],
)
# set audio config
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = False
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 256
# emotion
config.model_args.emotion_embedding_dim = 256
config.model_args.use_emotion_embedding_squeezer = False
config.model_args.emotion_embedding_squeezer_input_dim = 256
config.model_args.use_speaker_embedding_as_emotion = True
config.model_args.use_speaker_embedding_squeezer = False
config.model_args.speaker_embedding_squeezer_input_dim = 256
# consistency loss
# config.model_args.use_emotion_encoder_as_loss = True
# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
continue_speakers_path = config.model_args.d_vector_file
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)