mirror of https://github.com/coqui-ai/TTS.git
Support the use of speaker embedding as emotion embedding
This commit is contained in:
parent 360b969c23
commit 4b59f07946
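In plain terms: this change lets a VITS model reuse its speaker conditioning vector (the d-vector) as the emotion conditioning, instead of requiring a separate emotion embedding table or an external emotion embeddings file. A hedged configuration sketch of how the new flag is enabled (attribute names appear in this diff; the values are illustrative and mirror the test added at the bottom):

# illustrative sketch: reuse the 256-d speaker d-vector as the emotion conditioning
config.model_args.use_d_vector_file = True
config.model_args.d_vector_dim = 256
config.model_args.use_speaker_embedding_as_emotion = True
config.model_args.emotion_embedding_dim = 256  # kept equal to d_vector_dim since no squeezer is used here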
@@ -7,8 +7,7 @@ from tqdm import tqdm

 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.managers import save_file
-from TTS.tts.utils.managers import EmbeddingManager
+from TTS.tts.utils.managers import EmbeddingManager, save_file

 parser = argparse.ArgumentParser(
     description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
@@ -179,7 +179,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
     default=None,
 )
 parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
-parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None)
+parser.add_argument(
+    "--style_speaker_name",
+    type=str,
+    help="The speaker name from the style_wav. If not provided, the speaker embedding will be computed using the speaker encoder.",
+    default=None,
+)
 parser.add_argument(
     "--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
 )
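The new --style_speaker_name flag accompanies --style_wav so the synthesizer knows which speaker produced the prosody reference (otherwise the speaker encoder is used to compute it). A hypothetical invocation in the style of the repository's test scripts (paths and the speaker name are placeholders):

from tests import run_cli

# hypothetical example; the model is assumed to have a prosody/style encoder
inference_command = (
    "tts --text 'This is an example.' "
    "--model_path /path/to/model.pth --config_path /path/to/config.json "
    "--style_wav /path/to/reference.wav --style_speaker_name ljspeech-1"
)
run_cli(inference_command)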
@@ -32,7 +32,9 @@ class ReversalClassifier(nn.Module):
         reversal (bool): If True reversal the gradients. Default: True
     """

-    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
+    def __init__(
+        self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True
+    ):
         super().__init__()
         self.reversal = reversal
         self._lambda = scale_factor
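For context, the ReversalClassifier reformatted above is an adversarial classifier: with reversal=True its gradients are flipped and scaled before they reach the encoder, which pushes the encoder toward representations that carry no information about the classified attribute (speaker or emotion). A generic sketch of the gradient-reversal trick, not the repository's exact implementation:

import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negated, scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale=1.0):
        ctx.scale = scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # reversed gradient for x, no gradient for the scale argument
        return -ctx.scale * grad_output, None

# usage sketch: the classifier sees the features unchanged,
# but the encoder receives reversed gradients from the classification loss
# logits = classifier(GradReverse.apply(encoder_outputs, 0.25))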
@@ -9,13 +9,8 @@ from torch.nn import functional
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.ssim import ssim
 from TTS.utils.audio import TorchSTFT
-from TTS.vocoder.layers.losses import (
-    MelganFeatureLoss,
-    MSEDLoss,
-    MSEGLoss,
-    _apply_D_loss,
-    _apply_G_adv_loss,
-)
+from TTS.vocoder.layers.losses import MelganFeatureLoss, MSEDLoss, MSEGLoss, _apply_D_loss, _apply_G_adv_loss
+

 # pylint: disable=abstract-method
 # relates https://github.com/pytorch/pytorch/issues/42305
@@ -730,7 +725,6 @@ class VitsGeneratorLoss(nn.Module):
             loss += loss_prosody_enc_emo_classifier
             return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
-

         if loss_text_enc_spk_rev_classifier is not None:
             loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_text_enc_spk_rev_classifier
@@ -779,7 +773,6 @@ class VitsDiscriminatorLoss(nn.Module):
         self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
         self.disc_latent_gan_loss = MSEDLoss()
-

     @staticmethod
     def discriminator_loss(scores_real, scores_fake):
         loss = 0
@@ -109,7 +109,9 @@ class LatentDiscriminator(nn.Module):
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
         self.discriminators = nn.ModuleList(
             [
-                norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),
+                norm_f(
+                    nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))
+                ),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
@@ -1,5 +1,6 @@
-from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
+from TTS.tts.layers.tacotron.gst_layers import GST


 class VitsGST(GST):
     def __init__(self, *args, **kwargs):
@@ -9,6 +10,7 @@ class VitsGST(GST):
         style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
         return style_embed, None
+

 class VitsVAE(CapacitronVAE):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -19,12 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
-from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
-from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
-
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
@@ -37,7 +36,6 @@ from TTS.tts.utils.visual import plot_alignment
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
-

 ##############################
 # IO / Feature extraction
 ##############################
@@ -684,7 +682,11 @@ class Vits(BaseTTS):
             dp_cond_embedding_dim += self.args.prosody_embedding_dim

         dp_extra_inp_dim = 0
-        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings or self.args.use_speaker_embedding_as_emotion) and not self.args.use_noise_scale_predictor:
+        if (
+            self.args.use_emotion_embedding
+            or self.args.use_external_emotions_embeddings
+            or self.args.use_speaker_embedding_as_emotion
+        ) and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.emotion_embedding_dim

         if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
@@ -711,22 +713,22 @@ class Vits(BaseTTS):
             )

         if self.args.use_prosody_encoder:
-            if self.args.prosody_encoder_type == 'gst':
+            if self.args.prosody_encoder_type == "gst":
                 self.prosody_encoder = VitsGST(
                     num_mel=self.args.hidden_channels,
                     num_heads=self.args.prosody_encoder_num_heads,
                     num_style_tokens=self.args.prosody_encoder_num_tokens,
                     gst_embedding_dim=self.args.prosody_embedding_dim,
                 )
-            elif self.args.prosody_encoder_type == 'vae':
+            elif self.args.prosody_encoder_type == "vae":
                 self.prosody_encoder = VitsVAE(
                     num_mel=self.args.hidden_channels,
                     capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
                 )
             else:
                 raise RuntimeError(
                     f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
                 )
         if self.args.use_prosody_enc_spk_reversal_classifier:
             self.speaker_reversal_classifier = ReversalClassifier(
                 in_channels=self.args.prosody_embedding_dim,
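The prosody encoder is selected by prosody_encoder_type, and anything other than "gst" or "vae" now raises a RuntimeError. A hedged configuration sketch (attribute names are taken from this diff; the values are only illustrative):

config.model_args.use_prosody_encoder = True
config.model_args.prosody_encoder_type = "gst"  # or "vae"
config.model_args.prosody_embedding_dim = 64
config.model_args.prosody_encoder_num_heads = 1  # GST variant only
config.model_args.prosody_encoder_num_tokens = 5  # GST variant only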
@@ -738,12 +740,16 @@ class Vits(BaseTTS):
                 in_channels=self.args.prosody_embedding_dim,
                 out_channels=self.num_emotions,
                 hidden_channels=256,
-                reversal=False
+                reversal=False,
             )

         if self.args.use_noise_scale_predictor:
             noise_scale_predictor_input_dim = self.args.hidden_channels
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+            if (
+                self.args.use_emotion_embedding
+                or self.args.use_external_emotions_embeddings
+                or self.args.use_speaker_embedding_as_emotion
+            ):
                 noise_scale_predictor_input_dim += self.args.emotion_embedding_dim

             if self.args.use_prosody_encoder:
@@ -763,15 +769,18 @@ class Vits(BaseTTS):
             )

         if self.args.use_emotion_embedding_squeezer:
-            self.emotion_embedding_squeezer = nn.Linear(in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim)
+            self.emotion_embedding_squeezer = nn.Linear(
+                in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim
+            )

         if self.args.use_speaker_embedding_squeezer:
-            self.speaker_embedding_squeezer = nn.Linear(in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim)
+            self.speaker_embedding_squeezer = nn.Linear(
+                in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim
+            )

         if self.args.use_text_enc_spk_reversal_classifier:
             self.speaker_text_enc_reversal_classifier = ReversalClassifier(
-                in_channels=self.args.hidden_channels
-                + dp_extra_inp_dim,
+                in_channels=self.args.hidden_channels + dp_extra_inp_dim,
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
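The two "squeezer" modules above are plain nn.Linear projections that map an incoming embedding (for example a 256-dimensional speaker d-vector) down to the size the model expects; elsewhere in the diff the result is L2-normalized before use. A small standalone sketch with assumed sizes:

import torch
import torch.nn.functional as F
from torch import nn

# assumed sizes: squeeze a 256-d d-vector down to a 64-d emotion embedding
emotion_embedding_squeezer = nn.Linear(in_features=256, out_features=64)

d_vector = torch.randn(8, 256, 1)  # [batch, channels, 1], the layout used by the model
eg = F.normalize(emotion_embedding_squeezer(d_vector.squeeze(-1))).unsqueeze(-1)
print(eg.shape)  # torch.Size([8, 64, 1])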
@@ -781,7 +790,7 @@ class Vits(BaseTTS):
                 in_channels=self.args.hidden_channels,
                 out_channels=self.num_emotions,
                 hidden_channels=256,
-                reversal=False
+                reversal=False,
             )

         self.waveform_decoder = HifiganGenerator(
@@ -1176,9 +1185,16 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

+        if self.args.use_speaker_embedding_as_emotion:
+            eg = g
+
         # squeezers
         if self.args.use_emotion_embedding_squeezer:
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+            if (
+                self.args.use_emotion_embedding
+                or self.args.use_external_emotions_embeddings
+                or self.args.use_speaker_embedding_as_emotion
+            ):
                 eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)

         if self.args.use_speaker_embedding_squeezer:
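This block is the core of the feature: when use_speaker_embedding_as_emotion is set, the speaker conditioning g is simply reused as the emotion conditioning eg, and the squeezer branch now also covers that case. A condensed sketch of the selection logic (a hypothetical helper, not a function in the codebase):

import torch.nn.functional as F

def select_emotion_conditioning(g, emo_emb, args, emotion_embedding_squeezer=None):
    """Hypothetical helper mirroring the branch above; g and emo_emb are [B, C, 1] tensors."""
    eg = emo_emb
    if args.use_speaker_embedding_as_emotion:
        eg = g  # the speaker d-vector doubles as the emotion embedding
    if args.use_emotion_embedding_squeezer and eg is not None and emotion_embedding_squeezer is not None:
        # project to emotion_embedding_dim and L2-normalize, as the squeezer branch does
        eg = F.normalize(emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)
    return eg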
@@ -1200,7 +1216,7 @@ class Vits(BaseTTS):
             prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
             pros_emb, vae_outputs = self.prosody_encoder(
                 prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
-                y_lengths
+                y_lengths,
             )

             pros_emb = pros_emb.transpose(1, 2)
@@ -1215,7 +1231,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
         )

         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1250,10 +1266,14 @@ class Vits(BaseTTS):
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p_expanded, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )

             if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )

             nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
             m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
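For the noise-scale predictor, utterance-level embeddings (prosody or emotion) are broadcast over the time axis before being concatenated to the frame-level features. A small shape example with made-up sizes:

import torch

B, T, H, E = 2, 100, 192, 256  # assumed batch, frames, hidden and embedding sizes
nsp_input = torch.randn(B, T, H)  # m_p transposed to [B, T, H]
eg = torch.randn(B, E, 1)  # utterance-level embedding, [B, E, 1]

# repeat the embedding for every frame, then concatenate along the channel axis
eg_expanded = eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)  # [B, T, E]
nsp_input = torch.cat((nsp_input, eg_expanded), dim=-1)  # [B, T, H + E]
print(nsp_input.shape)  # torch.Size([2, 100, 448])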
@@ -1314,7 +1334,7 @@ class Vits(BaseTTS):
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
-                "loss_text_enc_emo_classifier": l_text_emotion
+                "loss_text_enc_emo_classifier": l_text_emotion,
             }
         )
         return outputs
@@ -1373,9 +1393,16 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

+        if self.args.use_speaker_embedding_as_emotion:
+            eg = g
+
         # squeezers
         if self.args.use_emotion_embedding_squeezer:
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+            if (
+                self.args.use_emotion_embedding
+                or self.args.use_external_emotions_embeddings
+                or self.args.use_speaker_embedding_as_emotion
+            ):
                 eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)

         if self.args.use_speaker_embedding_squeezer:
@@ -1399,13 +1426,12 @@ class Vits(BaseTTS):
             pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)

             pros_emb = pros_emb.transpose(1, 2)
-
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
         )

         # duration predictor
@@ -1448,10 +1474,14 @@ class Vits(BaseTTS):
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )

             if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )

             nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
             m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
@@ -1521,7 +1551,6 @@ class Vits(BaseTTS):
         wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
         return wav
-

     def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
         """Forward pass for voice conversion

@@ -1550,7 +1579,6 @@ class Vits(BaseTTS):
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
-

     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.

@@ -1599,18 +1627,16 @@ class Vits(BaseTTS):
             self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init

             # compute scores and features
-            scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _= self.disc(
-                outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
+            scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
+                outputs["model_outputs"].detach(),
+                outputs["waveform_seg"],
+                outputs["m_p"].detach(),
+                outputs["z_p"].detach(),
             )

             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
-                loss_dict = criterion[optimizer_idx](
-                    scores_disc_real,
-                    scores_disc_fake,
-                    scores_disc_zp,
-                    scores_disc_mp
-                )
+                loss_dict = criterion[optimizer_idx](scores_disc_real, scores_disc_fake, scores_disc_zp, scores_disc_mp)
                 return outputs, loss_dict

         if optimizer_idx == 1:
@@ -1640,8 +1666,20 @@ class Vits(BaseTTS):
             )

             # compute discriminator scores and features
-            scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc(
-                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
+            (
+                scores_disc_fake,
+                feats_disc_fake,
+                _,
+                feats_disc_real,
+                scores_disc_mp,
+                feats_disc_mp,
+                _,
+                feats_disc_zp,
+            ) = self.disc(
+                self.model_outputs_cache["model_outputs"],
+                self.model_outputs_cache["waveform_seg"],
+                self.model_outputs_cache["m_p"],
+                self.model_outputs_cache["z_p"].detach(),
             )

             # compute losses
@@ -1669,7 +1707,7 @@ class Vits(BaseTTS):
                     loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
                     scores_disc_mp=scores_disc_mp,
                     feats_disc_mp=feats_disc_mp,
-                    feats_disc_zp=feats_disc_zp
+                    feats_disc_zp=feats_disc_zp,
                 )

             return self.model_outputs_cache, loss_dict
@@ -1729,7 +1767,14 @@ class Vits(BaseTTS):
         config = self.config

         # extract speaker and language info
-        text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
+        text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )

         if isinstance(sentence_info, list):
             if len(sentence_info) == 1:
@@ -1748,12 +1793,18 @@ class Vits(BaseTTS):
                 text = sentence_info

         if style_wav and style_speaker_name is None:
-            raise RuntimeError(
-                " [!] You must to provide the style_speaker_name for the style_wav !!"
-            )
+            raise RuntimeError(" [!] You must provide the style_speaker_name for the style_wav !!")

         # get speaker id/d_vector
-        speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
+        speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
@@ -1762,7 +1813,9 @@ class Vits(BaseTTS):
                 d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)

                 if style_wav is not None:
-                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
+                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
+                        style_speaker_name, num_samples=None, randomize=False
+                    )

             elif config.use_speaker_embedding:
                 if speaker_name is None:
@@ -1893,7 +1946,15 @@ class Vits(BaseTTS):
             emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
             emotion_embeddings = torch.FloatTensor(emotion_embeddings)

-        if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
+        if (
+            self.emotion_manager is not None
+            and self.emotion_manager.embeddings
+            and (
+                self.args.use_emotion_embedding
+                or self.args.use_prosody_enc_emo_classifier
+                or self.args.use_text_enc_emo_classifier
+            )
+        ):
             emotion_mapping = self.emotion_manager.embeddings
             emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
             emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
@@ -94,7 +94,11 @@ class EmotionManager(EmbeddingManager):
            EmotionEncoder: Emotion encoder object.
        """
        emotion_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
+        if (
+            get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False)
+            or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False)
+            or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False)
+        ):
            if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
                emotion_manager = EmotionManager(
                    emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +110,11 @@ class EmotionManager(EmbeddingManager):
                )
            )

-        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
+        if (
+            get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False)
+            or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False)
+            or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False)
+        ):
            if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                emotion_manager = EmotionManager(
                    embeddings_file_path=get_from_config_or_model_args_with_default(
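EmotionManager.init_from_config therefore builds a manager whenever any emotion-related flag is on, either from a categorical emotion-ID file or from a file of precomputed emotion embeddings. A hedged sketch of the two configurations (paths are placeholders):

# option 1: categorical emotion IDs (placeholder path)
config.model_args.use_emotion_embedding = True
config.model_args.emotions_ids_file = "path/to/emotions_ids.json"

# option 2: precomputed emotion embeddings (placeholder path)
config.model_args.use_external_emotions_embeddings = True
config.model_args.external_emotions_embs_file = "path/to/emotion_embeddings.pth"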
@@ -217,7 +217,6 @@ def synthesis(
    if style_speaker_d_vector is not None:
        style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
-

    if not isinstance(style_feature, dict):
        # GST or Capacitron style mel
        style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
@@ -306,9 +306,17 @@ class Synthesizer(object):

        # handle emotion
        emotion_embedding, emotion_id = None, None
-        if not reference_wav and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
-            getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
-        )):
+        if (
+            not reference_wav
+            and not getattr(self.tts_model, "prosody_encoder", False)
+            and (
+                self.tts_emotions_file
+                or (
+                    getattr(self.tts_model, "emotion_manager", None)
+                    and getattr(self.tts_model.emotion_manager, "ids", None)
+                )
+            )
+        ):
            if emotion_name and isinstance(emotion_name, str):
                if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
                    getattr(self.tts_config, "model_args", None)
@@ -426,7 +434,7 @@ class Synthesizer(object):
                d_vector=speaker_embedding,
                use_griffin_lim=use_gl,
                reference_speaker_id=reference_speaker_id,
-                reference_d_vector=reference_speaker_embedding
+                reference_d_vector=reference_speaker_embedding,
            )
            waveform = outputs
            if not use_gl:
@@ -58,6 +58,9 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
+continue_speakers_path = os.path.join(continue_path, "speakers.json")
+if not os.path.isfile(continue_speakers_path):
+    continue_speakers_path = continue_speakers_path.replace(".json", ".pth")


 # Check integrity of the config
 with open(continue_config_path, "r", encoding="utf-8") as f:
@@ -37,23 +37,16 @@ config.audio.trim_db = 60
 config.model_args.use_speaker_embedding = False
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
-config.model_args.speaker_embedding_channels = 128
-config.model_args.d_vector_dim = 100
+config.model_args.speaker_embedding_channels = 256
+config.model_args.d_vector_dim = 256

 # emotion
 config.model_args.use_external_emotions_embeddings = True
 config.model_args.use_emotion_embedding = False
-config.model_args.emotion_embedding_dim = 64
+config.model_args.emotion_embedding_dim = 256
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_text_enc_spk_reversal_classifier = False

-
-config.model_args.use_emotion_embedding_squeezer = True
-config.model_args.emotion_embedding_squeezer_input_dim = 256
-
-config.model_args.use_speaker_embedding_squeezer = True
-config.model_args.speaker_embedding_squeezer_input_dim = 256
-
 # consistency loss
 # config.model_args.use_emotion_encoder_as_loss = True
 # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
@@ -0,0 +1,90 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# active multispeaker d-vec mode
+config.model_args.use_speaker_embedding = False
+config.model_args.use_d_vector_file = True
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
+config.model_args.d_vector_dim = 256
+
+# emotion
+config.model_args.emotion_embedding_dim = 256
+
+config.model_args.use_emotion_embedding_squeezer = False
+config.model_args.emotion_embedding_squeezer_input_dim = 256
+config.model_args.use_speaker_embedding_as_emotion = True
+
+config.model_args.use_speaker_embedding_squeezer = False
+config.model_args.speaker_embedding_squeezer_input_dim = 256
+
+# consistency loss
+# config.model_args.use_emotion_encoder_as_loss = True
+# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
+# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+continue_speakers_path = config.model_args.d_vector_file
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)