mirror of https://github.com/coqui-ai/TTS.git
Support the use of speaker embedding as emotion embedding
This commit is contained in:
  parent 360b969c23
  commit 4b59f07946
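This commit lets VITS reuse the speaker embedding (d-vector) as the emotion embedding: when the new `use_speaker_embedding_as_emotion` flag is set, the forward and inference passes assign the speaker conditioning vector to the emotion branch (`eg = g`), and the duration predictor and noise-scale predictor consume it like any other emotion embedding. A minimal configuration sketch, adapted from the test added at the end of this diff (the dimensions and file paths below are the test's own values, not requirements):

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()

    # multi-speaker mode driven by external d-vectors
    config.model_args.use_speaker_embedding = False
    config.model_args.use_d_vector_file = True
    config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
    config.model_args.d_vector_dim = 256

    # reuse the speaker d-vector as the emotion embedding
    config.model_args.use_speaker_embedding_as_emotion = True
    config.model_args.emotion_embedding_dim = 256
    config.model_args.use_emotion_embedding_squeezer = False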
@@ -7,8 +7,7 @@ from tqdm import tqdm
 
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.managers import save_file
-from TTS.tts.utils.managers import EmbeddingManager
+from TTS.tts.utils.managers import EmbeddingManager, save_file
 
 parser = argparse.ArgumentParser(
     description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
@@ -179,7 +179,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
     default=None,
 )
 parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
-parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None)
+parser.add_argument(
+    "--style_speaker_name",
+    type=str,
+    help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.",
+    default=None,
+)
 parser.add_argument(
     "--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
 )
@@ -32,7 +32,9 @@ class ReversalClassifier(nn.Module):
         reversal (bool): If True reversal the gradients. Default: True
     """
 
-    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
+    def __init__(
+        self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True
+    ):
         super().__init__()
         self.reversal = reversal
         self._lambda = scale_factor
@@ -9,13 +9,8 @@ from torch.nn import functional
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.ssim import ssim
 from TTS.utils.audio import TorchSTFT
-from TTS.vocoder.layers.losses import (
-    MelganFeatureLoss,
-    MSEDLoss,
-    MSEGLoss,
-    _apply_D_loss,
-    _apply_G_adv_loss,
-)
+from TTS.vocoder.layers.losses import MelganFeatureLoss, MSEDLoss, MSEGLoss, _apply_D_loss, _apply_G_adv_loss
+
 
 # pylint: disable=abstract-method
 # relates https://github.com/pytorch/pytorch/issues/42305
@@ -730,7 +725,6 @@ class VitsGeneratorLoss(nn.Module):
             loss += loss_prosody_enc_emo_classifier
             return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
 
-
         if loss_text_enc_spk_rev_classifier is not None:
             loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_text_enc_spk_rev_classifier
@@ -779,7 +773,6 @@ class VitsDiscriminatorLoss(nn.Module):
         self.disc_latent_loss_alpha = c.disc_latent_loss_alpha
         self.disc_latent_gan_loss = MSEDLoss()
 
-
     @staticmethod
     def discriminator_loss(scores_real, scores_fake):
         loss = 0
@@ -109,7 +109,9 @@ class LatentDiscriminator(nn.Module):
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
         self.discriminators = nn.ModuleList(
             [
-                norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))),
+                norm_f(
+                    nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))
+                ),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
                 norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
@@ -1,5 +1,6 @@
-from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
+from TTS.tts.layers.tacotron.gst_layers import GST
+
 
 class VitsGST(GST):
     def __init__(self, *args, **kwargs):
@@ -9,6 +10,7 @@ class VitsGST(GST):
         style_embed = super().forward(inputs, speaker_embedding=speaker_embedding)
         return style_embed, None
 
+
 class VitsVAE(CapacitronVAE):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -19,12 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
-from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
-from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
-
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
@@ -37,7 +36,6 @@ from TTS.tts.utils.visual import plot_alignment
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 
-
 ##############################
 # IO / Feature extraction
 ##############################
@@ -684,7 +682,11 @@ class Vits(BaseTTS):
             dp_cond_embedding_dim += self.args.prosody_embedding_dim
 
         dp_extra_inp_dim = 0
-        if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings or self.args.use_speaker_embedding_as_emotion) and not self.args.use_noise_scale_predictor:
+        if (
+            self.args.use_emotion_embedding
+            or self.args.use_external_emotions_embeddings
+            or self.args.use_speaker_embedding_as_emotion
+        ) and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.emotion_embedding_dim
 
         if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
@@ -711,22 +713,22 @@ class Vits(BaseTTS):
         )
 
         if self.args.use_prosody_encoder:
-            if self.args.prosody_encoder_type == 'gst':
+            if self.args.prosody_encoder_type == "gst":
                 self.prosody_encoder = VitsGST(
                     num_mel=self.args.hidden_channels,
                     num_heads=self.args.prosody_encoder_num_heads,
                     num_style_tokens=self.args.prosody_encoder_num_tokens,
                     gst_embedding_dim=self.args.prosody_embedding_dim,
                 )
-            elif self.args.prosody_encoder_type == 'vae':
+            elif self.args.prosody_encoder_type == "vae":
                 self.prosody_encoder = VitsVAE(
                     num_mel=self.args.hidden_channels,
                     capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
                 )
             else:
                 raise RuntimeError(
                     f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
                 )
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 self.speaker_reversal_classifier = ReversalClassifier(
                     in_channels=self.args.prosody_embedding_dim,
@@ -738,12 +740,16 @@ class Vits(BaseTTS):
                     in_channels=self.args.prosody_embedding_dim,
                     out_channels=self.num_emotions,
                     hidden_channels=256,
-                    reversal=False
+                    reversal=False,
                 )
 
         if self.args.use_noise_scale_predictor:
             noise_scale_predictor_input_dim = self.args.hidden_channels
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+            if (
+                self.args.use_emotion_embedding
+                or self.args.use_external_emotions_embeddings
+                or self.args.use_speaker_embedding_as_emotion
+            ):
                 noise_scale_predictor_input_dim += self.args.emotion_embedding_dim
 
             if self.args.use_prosody_encoder:
@@ -763,15 +769,18 @@ class Vits(BaseTTS):
         )
 
         if self.args.use_emotion_embedding_squeezer:
-            self.emotion_embedding_squeezer = nn.Linear(in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim)
+            self.emotion_embedding_squeezer = nn.Linear(
+                in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim
+            )
 
         if self.args.use_speaker_embedding_squeezer:
-            self.speaker_embedding_squeezer = nn.Linear(in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim)
+            self.speaker_embedding_squeezer = nn.Linear(
+                in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim
+            )
 
         if self.args.use_text_enc_spk_reversal_classifier:
             self.speaker_text_enc_reversal_classifier = ReversalClassifier(
-                in_channels=self.args.hidden_channels
-                + dp_extra_inp_dim,
+                in_channels=self.args.hidden_channels + dp_extra_inp_dim,
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
@@ -781,7 +790,7 @@ class Vits(BaseTTS):
                 in_channels=self.args.hidden_channels,
                 out_channels=self.num_emotions,
                 hidden_channels=256,
-                reversal=False
+                reversal=False,
             )
 
         self.waveform_decoder = HifiganGenerator(
@@ -1176,9 +1185,16 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
 
+        if self.args.use_speaker_embedding_as_emotion:
+            eg = g
+
         # squeezers
         if self.args.use_emotion_embedding_squeezer:
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+            if (
+                self.args.use_emotion_embedding
+                or self.args.use_external_emotions_embeddings
+                or self.args.use_speaker_embedding_as_emotion
+            ):
                 eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)
 
         if self.args.use_speaker_embedding_squeezer:
@@ -1200,7 +1216,7 @@ class Vits(BaseTTS):
             prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z
             pros_emb, vae_outputs = self.prosody_encoder(
                 prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
-                y_lengths
+                y_lengths,
             )
 
             pros_emb = pros_emb.transpose(1, 2)
@@ -1215,7 +1231,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
         )
 
         # reversal speaker loss to force the encoder to be speaker identity free
@@ -1250,10 +1266,14 @@ class Vits(BaseTTS):
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p_expanded, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )
 
             if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )
 
             nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
             m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
@@ -1314,7 +1334,7 @@ class Vits(BaseTTS):
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
                 "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
-                "loss_text_enc_emo_classifier": l_text_emotion
+                "loss_text_enc_emo_classifier": l_text_emotion,
             }
         )
         return outputs
@@ -1373,9 +1393,16 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
 
+        if self.args.use_speaker_embedding_as_emotion:
+            eg = g
+
         # squeezers
         if self.args.use_emotion_embedding_squeezer:
-            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+            if (
+                self.args.use_emotion_embedding
+                or self.args.use_external_emotions_embeddings
+                or self.args.use_speaker_embedding_as_emotion
+            ):
                 eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1)
 
         if self.args.use_speaker_embedding_squeezer:
@@ -1399,13 +1426,12 @@ class Vits(BaseTTS):
             pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
-
             pros_emb = pros_emb.transpose(1, 2)
 
         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
         )
 
         # duration predictor
@@ -1448,10 +1474,14 @@ class Vits(BaseTTS):
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
-                nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )
 
             if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None:
-                nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1)
+                nsp_input = torch.cat(
+                    (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1
+                )
 
             nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask
             m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask)
@@ -1521,7 +1551,6 @@ class Vits(BaseTTS):
         wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
         return wav
 
-
     def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
         """Forward pass for voice conversion
 
@@ -1550,7 +1579,6 @@ class Vits(BaseTTS):
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
 
-
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
 
@@ -1599,18 +1627,16 @@ class Vits(BaseTTS):
             self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
 
             # compute scores and features
-            scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _= self.disc(
-                outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach()
+            scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc(
+                outputs["model_outputs"].detach(),
+                outputs["waveform_seg"],
+                outputs["m_p"].detach(),
+                outputs["z_p"].detach(),
             )
 
             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
-                loss_dict = criterion[optimizer_idx](
-                    scores_disc_real,
-                    scores_disc_fake,
-                    scores_disc_zp,
-                    scores_disc_mp
-                )
+                loss_dict = criterion[optimizer_idx](scores_disc_real, scores_disc_fake, scores_disc_zp, scores_disc_mp)
             return outputs, loss_dict
 
         if optimizer_idx == 1:
@@ -1640,8 +1666,20 @@ class Vits(BaseTTS):
             )
 
             # compute discriminator scores and features
-            scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc(
-                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach()
+            (
+                scores_disc_fake,
+                feats_disc_fake,
+                _,
+                feats_disc_real,
+                scores_disc_mp,
+                feats_disc_mp,
+                _,
+                feats_disc_zp,
+            ) = self.disc(
+                self.model_outputs_cache["model_outputs"],
+                self.model_outputs_cache["waveform_seg"],
+                self.model_outputs_cache["m_p"],
+                self.model_outputs_cache["z_p"].detach(),
             )
 
             # compute losses
@@ -1669,7 +1707,7 @@ class Vits(BaseTTS):
                     loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
                     scores_disc_mp=scores_disc_mp,
                     feats_disc_mp=feats_disc_mp,
-                    feats_disc_zp=feats_disc_zp
+                    feats_disc_zp=feats_disc_zp,
                 )
 
             return self.model_outputs_cache, loss_dict
@@ -1729,7 +1767,14 @@ class Vits(BaseTTS):
         config = self.config
 
         # extract speaker and language info
-        text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
+        text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
 
         if isinstance(sentence_info, list):
             if len(sentence_info) == 1:
@@ -1748,12 +1793,18 @@ class Vits(BaseTTS):
             text = sentence_info
 
         if style_wav and style_speaker_name is None:
-            raise RuntimeError(
-                " [!] You must to provide the style_speaker_name for the style_wav !!"
-            )
+            raise RuntimeError(" [!] You must to provide the style_speaker_name for the style_wav !!")
 
         # get speaker id/d_vector
-        speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
+        speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
@@ -1762,7 +1813,9 @@ class Vits(BaseTTS):
                     d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
 
                 if style_wav is not None:
-                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
+                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
+                        style_speaker_name, num_samples=None, randomize=False
+                    )
 
             elif config.use_speaker_embedding:
                 if speaker_name is None:
@@ -1893,7 +1946,15 @@ class Vits(BaseTTS):
             emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
             emotion_embeddings = torch.FloatTensor(emotion_embeddings)
 
-        if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
+        if (
+            self.emotion_manager is not None
+            and self.emotion_manager.embeddings
+            and (
+                self.args.use_emotion_embedding
+                or self.args.use_prosody_enc_emo_classifier
+                or self.args.use_text_enc_emo_classifier
+            )
+        ):
             emotion_mapping = self.emotion_manager.embeddings
             emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
             emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
@@ -94,7 +94,11 @@ class EmotionManager(EmbeddingManager):
             EmotionEncoder: Emotion encoder object.
         """
         emotion_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
+        if (
+            get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False)
+            or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False)
+            or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False)
+        ):
             if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
                 emotion_manager = EmotionManager(
                     emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +110,11 @@ class EmotionManager(EmbeddingManager):
                     )
                 )
 
-        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False):
+        if (
+            get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False)
+            or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False)
+            or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False)
+        ):
             if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                 emotion_manager = EmotionManager(
                     embeddings_file_path=get_from_config_or_model_args_with_default(
@@ -217,7 +217,6 @@ def synthesis(
     if style_speaker_d_vector is not None:
         style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
 
-
     if not isinstance(style_feature, dict):
         # GST or Capacitron style mel
         style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
@@ -306,9 +306,17 @@ class Synthesizer(object):
 
         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if not reference_wav and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
-            getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
-        )):
+        if (
+            not reference_wav
+            and not getattr(self.tts_model, "prosody_encoder", False)
+            and (
+                self.tts_emotions_file
+                or (
+                    getattr(self.tts_model, "emotion_manager", None)
+                    and getattr(self.tts_model.emotion_manager, "ids", None)
+                )
+            )
+        ):
             if emotion_name and isinstance(emotion_name, str):
                 if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
                     getattr(self.tts_config, "model_args", None)
@@ -426,7 +434,7 @@ class Synthesizer(object):
                     d_vector=speaker_embedding,
                     use_griffin_lim=use_gl,
                     reference_speaker_id=reference_speaker_id,
-                    reference_d_vector=reference_speaker_embedding
+                    reference_d_vector=reference_speaker_embedding,
                 )
             waveform = outputs
             if not use_gl:
@@ -58,6 +58,9 @@ continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")
+if not os.path.isfile(continue_speakers_path):
+    continue_speakers_path = continue_speakers_path.replace(".json", ".pth")
+
 
 # Check integrity of the config
 with open(continue_config_path, "r", encoding="utf-8") as f:
@@ -37,23 +37,16 @@ config.audio.trim_db = 60
 config.model_args.use_speaker_embedding = False
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
-config.model_args.speaker_embedding_channels = 128
-config.model_args.d_vector_dim = 100
+config.model_args.speaker_embedding_channels = 256
+config.model_args.d_vector_dim = 256
 
 # emotion
 config.model_args.use_external_emotions_embeddings = True
 config.model_args.use_emotion_embedding = False
-config.model_args.emotion_embedding_dim = 64
+config.model_args.emotion_embedding_dim = 256
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_text_enc_spk_reversal_classifier = False
 
-
-config.model_args.use_emotion_embedding_squeezer = True
-config.model_args.emotion_embedding_squeezer_input_dim = 256
-
-config.model_args.use_speaker_embedding_squeezer = True
-config.model_args.speaker_embedding_squeezer_input_dim = 256
-
 # consistency loss
 # config.model_args.use_emotion_encoder_as_loss = True
 # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
@@ -0,0 +1,90 @@
+import glob
+import os
+import shutil
+
+from trainer import get_last_checkpoint
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs.vits_config import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
+    ],
+)
+# set audio config
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+
+# active multispeaker d-vec mode
+config.model_args.use_speaker_embedding = False
+config.model_args.use_d_vector_file = True
+config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
+config.model_args.speaker_embedding_channels = 128
+config.model_args.d_vector_dim = 256
+
+# emotion
+config.model_args.emotion_embedding_dim = 256
+
+config.model_args.use_emotion_embedding_squeezer = False
+config.model_args.emotion_embedding_squeezer_input_dim = 256
+config.model_args.use_speaker_embedding_as_emotion = True
+
+config.model_args.use_speaker_embedding_squeezer = False
+config.model_args.speaker_embedding_squeezer_input_dim = 256
+
+# consistency loss
+# config.model_args.use_emotion_encoder_as_loss = True
+# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
+# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json"
+
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# Inference using TTS API
+continue_config_path = os.path.join(continue_path, "config.json")
+continue_restore_path, _ = get_last_checkpoint(continue_path)
+out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
+speaker_id = "ljspeech-1"
+continue_speakers_path = config.model_args.d_vector_file
+
+
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
+run_cli(inference_command)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)