mirror of https://github.com/coqui-ai/TTS.git
Add text encoder reversal speaker classifier loss
This commit is contained in:
parent
a543d71352
commit
ac3f98cefb
|
@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig):
|
|||
mel_loss_alpha: float = 45.0
|
||||
dur_loss_alpha: float = 1.0
|
||||
consistency_loss_alpha: float = 1.0
|
||||
text_enc_spk_reversal_loss_alpha: float = 2.0
|
||||
|
||||
# data loader params
|
||||
return_wav: bool = True
|
||||
|
|
|
@ -590,6 +590,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
self.dur_loss_alpha = c.dur_loss_alpha
|
||||
self.mel_loss_alpha = c.mel_loss_alpha
|
||||
self.consistency_loss_alpha = c.consistency_loss_alpha
|
||||
self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
|
||||
|
||||
self.stft = TorchSTFT(
|
||||
c.audio.fft_size,
|
||||
c.audio.hop_length,
|
||||
|
@ -662,7 +664,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
use_encoder_consistency_loss=False,
|
||||
gt_cons_emb=None,
|
||||
syn_cons_emb=None,
|
||||
loss_spk_reversal_classifier=None,
|
||||
loss_prosody_enc_spk_rev_classifier=None,
|
||||
loss_text_enc_spk_rev_classifier=None,
|
||||
):
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -698,9 +701,14 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss = loss + loss_enc
|
||||
return_dict["loss_consistency_enc"] = loss_enc
|
||||
|
||||
if loss_spk_reversal_classifier is not None:
|
||||
loss += loss_spk_reversal_classifier
|
||||
return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
|
||||
if loss_prosody_enc_spk_rev_classifier is not None:
|
||||
loss += loss_prosody_enc_spk_rev_classifier
|
||||
return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier
|
||||
|
||||
if loss_text_enc_spk_rev_classifier is not None:
|
||||
loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
|
||||
loss += loss_text_enc_spk_rev_classifier
|
||||
return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier
|
||||
|
||||
# pass losses to the dict
|
||||
return_dict["loss_gen"] = loss_gen
|
||||
|
|
|
@ -540,6 +540,7 @@ class VitsArgs(Coqpit):
|
|||
external_emotions_embs_file: str = None
|
||||
emotion_embedding_dim: int = 0
|
||||
num_emotions: int = 0
|
||||
use_text_enc_spk_reversal_classifier: bool = False
|
||||
|
||||
# prosody encoder
|
||||
use_prosody_encoder: bool = False
|
||||
|
@ -689,12 +690,19 @@ class Vits(BaseTTS):
|
|||
num_style_tokens=self.args.prosody_encoder_num_tokens,
|
||||
gst_embedding_dim=self.args.prosody_embedding_dim,
|
||||
)
|
||||
self.speaker_reversal_classifier = ReversalClassifier(
|
||||
self.speaker_pros_enc_reversal_classifier = ReversalClassifier(
|
||||
in_channels=self.args.prosody_embedding_dim,
|
||||
out_channels=self.num_speakers,
|
||||
hidden_channels=256,
|
||||
)
|
||||
|
||||
if self.args.use_text_enc_spk_reversal_classifier:
|
||||
self.speaker_text_enc_reversal_classifier = ReversalClassifier(
|
||||
in_channels=self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
|
||||
out_channels=self.num_speakers,
|
||||
hidden_channels=256,
|
||||
)
|
||||
|
||||
self.waveform_decoder = HifiganGenerator(
|
||||
self.args.hidden_channels,
|
||||
1,
|
||||
|
@ -1081,10 +1089,15 @@ class Vits(BaseTTS):
|
|||
l_pros_speaker = None
|
||||
if self.args.use_prosody_encoder:
|
||||
pros_emb = self.prosody_encoder(z).transpose(1, 2)
|
||||
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
|
||||
_, l_pros_speaker = self.speaker_pros_enc_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
|
||||
|
||||
# reversal speaker loss to force the encoder to be speaker identity free
|
||||
l_text_speaker = None
|
||||
if self.args.use_text_enc_spk_reversal_classifier:
|
||||
_, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None)
|
||||
|
||||
# flow layers
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
|
@ -1159,7 +1172,8 @@ class Vits(BaseTTS):
|
|||
"gt_cons_emb": gt_cons_emb,
|
||||
"syn_cons_emb": syn_cons_emb,
|
||||
"slice_ids": slice_ids,
|
||||
"loss_spk_reversal_classifier": l_pros_speaker,
|
||||
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
|
||||
"loss_text_enc_spk_rev_classifier": l_text_speaker,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
@ -1465,7 +1479,8 @@ class Vits(BaseTTS):
|
|||
or self.args.use_emotion_encoder_as_loss,
|
||||
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
|
||||
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
|
||||
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
|
||||
loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
|
||||
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"]
|
||||
)
|
||||
|
||||
return self.model_outputs_cache, loss_dict
|
||||
|
@ -1636,7 +1651,7 @@ class Vits(BaseTTS):
|
|||
if (
|
||||
self.speaker_manager is not None
|
||||
and self.speaker_manager.ids
|
||||
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
|
||||
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder or self.args.use_text_enc_spk_reversal_classifier)
|
||||
):
|
||||
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
|
||||
|
||||
|
|
|
@ -34,18 +34,18 @@ config.audio.do_trim_silence = True
|
|||
config.audio.trim_db = 60
|
||||
|
||||
# active multispeaker d-vec mode
|
||||
config.model_args.use_speaker_embedding = True
|
||||
config.model_args.use_d_vector_file = False
|
||||
config.model_args.use_speaker_embedding = False
|
||||
config.model_args.use_d_vector_file = True
|
||||
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
|
||||
config.model_args.speaker_embedding_channels = 128
|
||||
config.model_args.d_vector_dim = 256
|
||||
|
||||
# emotion
|
||||
config.model_args.use_external_emotions_embeddings = False
|
||||
config.model_args.use_emotion_embedding = True
|
||||
config.model_args.emotion_just_encoder = False
|
||||
config.model_args.use_external_emotions_embeddings = True
|
||||
config.model_args.use_emotion_embedding = False
|
||||
config.model_args.emotion_embedding_dim = 256
|
||||
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
||||
config.model_args.use_text_enc_spk_reversal_classifier = False
|
||||
|
||||
# consistency loss
|
||||
# config.model_args.use_emotion_encoder_as_loss = True
|
||||
|
|
Loading…
Reference in New Issue