Add text-encoder reversal speaker-classifier loss

This commit is contained in:
Edresson Casanova 2022-05-17 13:14:15 +00:00
parent a543d71352
commit ac3f98cefb
4 changed files with 38 additions and 14 deletions

View File

@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig):
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
consistency_loss_alpha: float = 1.0
text_enc_spk_reversal_loss_alpha: float = 2.0
# data loader params
return_wav: bool = True

View File

@ -590,6 +590,8 @@ class VitsGeneratorLoss(nn.Module):
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.consistency_loss_alpha = c.consistency_loss_alpha
self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
c.audio.hop_length,
@ -662,7 +664,8 @@ class VitsGeneratorLoss(nn.Module):
use_encoder_consistency_loss=False,
gt_cons_emb=None,
syn_cons_emb=None,
loss_spk_reversal_classifier=None,
loss_prosody_enc_spk_rev_classifier=None,
loss_text_enc_spk_rev_classifier=None,
):
"""
Shapes:
@ -698,9 +701,14 @@ class VitsGeneratorLoss(nn.Module):
loss = loss + loss_enc
return_dict["loss_consistency_enc"] = loss_enc
if loss_spk_reversal_classifier is not None:
loss += loss_spk_reversal_classifier
return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
if loss_prosody_enc_spk_rev_classifier is not None:
loss += loss_prosody_enc_spk_rev_classifier
return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier
if loss_text_enc_spk_rev_classifier is not None:
loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
loss += loss_text_enc_spk_rev_classifier
return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier
# pass losses to the dict
return_dict["loss_gen"] = loss_gen

View File

@ -540,6 +540,7 @@ class VitsArgs(Coqpit):
external_emotions_embs_file: str = None
emotion_embedding_dim: int = 0
num_emotions: int = 0
use_text_enc_spk_reversal_classifier: bool = False
# prosody encoder
use_prosody_encoder: bool = False
@ -689,12 +690,19 @@ class Vits(BaseTTS):
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
)
self.speaker_reversal_classifier = ReversalClassifier(
self.speaker_pros_enc_reversal_classifier = ReversalClassifier(
in_channels=self.args.prosody_embedding_dim,
out_channels=self.num_speakers,
hidden_channels=256,
)
if self.args.use_text_enc_spk_reversal_classifier:
self.speaker_text_enc_reversal_classifier = ReversalClassifier(
in_channels=self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
out_channels=self.num_speakers,
hidden_channels=256,
)
self.waveform_decoder = HifiganGenerator(
self.args.hidden_channels,
1,
@ -1081,10 +1089,15 @@ class Vits(BaseTTS):
l_pros_speaker = None
if self.args.use_prosody_encoder:
pros_emb = self.prosody_encoder(z).transpose(1, 2)
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
_, l_pros_speaker = self.speaker_pros_enc_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
# reversal speaker loss to force the encoder to be speaker identity free
l_text_speaker = None
if self.args.use_text_enc_spk_reversal_classifier:
_, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None)
# flow layers
z_p = self.flow(z, y_mask, g=g)
@ -1159,7 +1172,8 @@ class Vits(BaseTTS):
"gt_cons_emb": gt_cons_emb,
"syn_cons_emb": syn_cons_emb,
"slice_ids": slice_ids,
"loss_spk_reversal_classifier": l_pros_speaker,
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
"loss_text_enc_spk_rev_classifier": l_text_speaker,
}
)
return outputs
@ -1465,7 +1479,8 @@ class Vits(BaseTTS):
or self.args.use_emotion_encoder_as_loss,
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"]
)
return self.model_outputs_cache, loss_dict
@ -1636,7 +1651,7 @@ class Vits(BaseTTS):
if (
self.speaker_manager is not None
and self.speaker_manager.ids
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder or self.args.use_text_enc_spk_reversal_classifier)
):
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]

View File

@ -34,18 +34,18 @@ config.audio.do_trim_silence = True
config.audio.trim_db = 60
# active multispeaker d-vec mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.use_speaker_embedding = False
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 256
# emotion
config.model_args.use_external_emotions_embeddings = False
config.model_args.use_emotion_embedding = True
config.model_args.emotion_just_encoder = False
config.model_args.use_external_emotions_embeddings = True
config.model_args.use_emotion_embedding = False
config.model_args.emotion_embedding_dim = 256
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.model_args.use_text_enc_spk_reversal_classifier = False
# consistency loss
# config.model_args.use_emotion_encoder_as_loss = True