mirror of https://github.com/coqui-ai/TTS.git
Add text encoder reversal speaker classifier loss
This commit is contained in:
parent
024e567849
commit
8505cd09e8
|
@ -115,6 +115,7 @@ class VitsConfig(BaseTTSConfig):
|
||||||
mel_loss_alpha: float = 45.0
|
mel_loss_alpha: float = 45.0
|
||||||
dur_loss_alpha: float = 1.0
|
dur_loss_alpha: float = 1.0
|
||||||
consistency_loss_alpha: float = 1.0
|
consistency_loss_alpha: float = 1.0
|
||||||
|
text_enc_spk_reversal_loss_alpha: float = 2.0
|
||||||
|
|
||||||
# data loader params
|
# data loader params
|
||||||
return_wav: bool = True
|
return_wav: bool = True
|
||||||
|
|
|
@ -590,6 +590,8 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
self.dur_loss_alpha = c.dur_loss_alpha
|
self.dur_loss_alpha = c.dur_loss_alpha
|
||||||
self.mel_loss_alpha = c.mel_loss_alpha
|
self.mel_loss_alpha = c.mel_loss_alpha
|
||||||
self.consistency_loss_alpha = c.consistency_loss_alpha
|
self.consistency_loss_alpha = c.consistency_loss_alpha
|
||||||
|
self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
|
||||||
|
|
||||||
self.stft = TorchSTFT(
|
self.stft = TorchSTFT(
|
||||||
c.audio.fft_size,
|
c.audio.fft_size,
|
||||||
c.audio.hop_length,
|
c.audio.hop_length,
|
||||||
|
@ -662,7 +664,8 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
use_encoder_consistency_loss=False,
|
use_encoder_consistency_loss=False,
|
||||||
gt_cons_emb=None,
|
gt_cons_emb=None,
|
||||||
syn_cons_emb=None,
|
syn_cons_emb=None,
|
||||||
loss_spk_reversal_classifier=None,
|
loss_prosody_enc_spk_rev_classifier=None,
|
||||||
|
loss_text_enc_spk_rev_classifier=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
|
@ -698,9 +701,14 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
loss = loss + loss_enc
|
loss = loss + loss_enc
|
||||||
return_dict["loss_consistency_enc"] = loss_enc
|
return_dict["loss_consistency_enc"] = loss_enc
|
||||||
|
|
||||||
if loss_spk_reversal_classifier is not None:
|
if loss_prosody_enc_spk_rev_classifier is not None:
|
||||||
loss += loss_spk_reversal_classifier
|
loss += loss_prosody_enc_spk_rev_classifier
|
||||||
return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
|
return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier
|
||||||
|
|
||||||
|
if loss_text_enc_spk_rev_classifier is not None:
|
||||||
|
loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
|
||||||
|
loss += loss_text_enc_spk_rev_classifier
|
||||||
|
return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier
|
||||||
|
|
||||||
# pass losses to the dict
|
# pass losses to the dict
|
||||||
return_dict["loss_gen"] = loss_gen
|
return_dict["loss_gen"] = loss_gen
|
||||||
|
|
|
@ -540,6 +540,7 @@ class VitsArgs(Coqpit):
|
||||||
external_emotions_embs_file: str = None
|
external_emotions_embs_file: str = None
|
||||||
emotion_embedding_dim: int = 0
|
emotion_embedding_dim: int = 0
|
||||||
num_emotions: int = 0
|
num_emotions: int = 0
|
||||||
|
use_text_enc_spk_reversal_classifier: bool = False
|
||||||
|
|
||||||
# prosody encoder
|
# prosody encoder
|
||||||
use_prosody_encoder: bool = False
|
use_prosody_encoder: bool = False
|
||||||
|
@ -689,12 +690,19 @@ class Vits(BaseTTS):
|
||||||
num_style_tokens=self.args.prosody_encoder_num_tokens,
|
num_style_tokens=self.args.prosody_encoder_num_tokens,
|
||||||
gst_embedding_dim=self.args.prosody_embedding_dim,
|
gst_embedding_dim=self.args.prosody_embedding_dim,
|
||||||
)
|
)
|
||||||
self.speaker_reversal_classifier = ReversalClassifier(
|
self.speaker_pros_enc_reversal_classifier = ReversalClassifier(
|
||||||
in_channels=self.args.prosody_embedding_dim,
|
in_channels=self.args.prosody_embedding_dim,
|
||||||
out_channels=self.num_speakers,
|
out_channels=self.num_speakers,
|
||||||
hidden_channels=256,
|
hidden_channels=256,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.args.use_text_enc_spk_reversal_classifier:
|
||||||
|
self.speaker_text_enc_reversal_classifier = ReversalClassifier(
|
||||||
|
in_channels=self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
|
||||||
|
out_channels=self.num_speakers,
|
||||||
|
hidden_channels=256,
|
||||||
|
)
|
||||||
|
|
||||||
self.waveform_decoder = HifiganGenerator(
|
self.waveform_decoder = HifiganGenerator(
|
||||||
self.args.hidden_channels,
|
self.args.hidden_channels,
|
||||||
1,
|
1,
|
||||||
|
@ -1081,10 +1089,15 @@ class Vits(BaseTTS):
|
||||||
l_pros_speaker = None
|
l_pros_speaker = None
|
||||||
if self.args.use_prosody_encoder:
|
if self.args.use_prosody_encoder:
|
||||||
pros_emb = self.prosody_encoder(z).transpose(1, 2)
|
pros_emb = self.prosody_encoder(z).transpose(1, 2)
|
||||||
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
|
_, l_pros_speaker = self.speaker_pros_enc_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
|
||||||
|
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
|
||||||
|
|
||||||
|
# speaker reversal loss: forces the text encoder output to be free of speaker identity
|
||||||
|
l_text_speaker = None
|
||||||
|
if self.args.use_text_enc_spk_reversal_classifier:
|
||||||
|
_, l_text_speaker = self.speaker_text_enc_reversal_classifier(x.transpose(1, 2), sid, x_mask=None)
|
||||||
|
|
||||||
# flow layers
|
# flow layers
|
||||||
z_p = self.flow(z, y_mask, g=g)
|
z_p = self.flow(z, y_mask, g=g)
|
||||||
|
|
||||||
|
@ -1159,7 +1172,8 @@ class Vits(BaseTTS):
|
||||||
"gt_cons_emb": gt_cons_emb,
|
"gt_cons_emb": gt_cons_emb,
|
||||||
"syn_cons_emb": syn_cons_emb,
|
"syn_cons_emb": syn_cons_emb,
|
||||||
"slice_ids": slice_ids,
|
"slice_ids": slice_ids,
|
||||||
"loss_spk_reversal_classifier": l_pros_speaker,
|
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
|
||||||
|
"loss_text_enc_spk_rev_classifier": l_text_speaker,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
@ -1465,7 +1479,8 @@ class Vits(BaseTTS):
|
||||||
or self.args.use_emotion_encoder_as_loss,
|
or self.args.use_emotion_encoder_as_loss,
|
||||||
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
|
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
|
||||||
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
|
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
|
||||||
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
|
loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
|
||||||
|
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.model_outputs_cache, loss_dict
|
return self.model_outputs_cache, loss_dict
|
||||||
|
@ -1636,7 +1651,7 @@ class Vits(BaseTTS):
|
||||||
if (
|
if (
|
||||||
self.speaker_manager is not None
|
self.speaker_manager is not None
|
||||||
and self.speaker_manager.ids
|
and self.speaker_manager.ids
|
||||||
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
|
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder or self.args.use_text_enc_spk_reversal_classifier)
|
||||||
):
|
):
|
||||||
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
|
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
|
||||||
|
|
||||||
|
|
|
@ -34,18 +34,18 @@ config.audio.do_trim_silence = True
|
||||||
config.audio.trim_db = 60
|
config.audio.trim_db = 60
|
||||||
|
|
||||||
# activate multispeaker d-vec mode
|
# activate multispeaker d-vec mode
|
||||||
config.model_args.use_speaker_embedding = True
|
config.model_args.use_speaker_embedding = False
|
||||||
config.model_args.use_d_vector_file = False
|
config.model_args.use_d_vector_file = True
|
||||||
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
|
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
|
||||||
config.model_args.speaker_embedding_channels = 128
|
config.model_args.speaker_embedding_channels = 128
|
||||||
config.model_args.d_vector_dim = 256
|
config.model_args.d_vector_dim = 256
|
||||||
|
|
||||||
# emotion
|
# emotion
|
||||||
config.model_args.use_external_emotions_embeddings = False
|
config.model_args.use_external_emotions_embeddings = True
|
||||||
config.model_args.use_emotion_embedding = True
|
config.model_args.use_emotion_embedding = False
|
||||||
config.model_args.emotion_just_encoder = False
|
|
||||||
config.model_args.emotion_embedding_dim = 256
|
config.model_args.emotion_embedding_dim = 256
|
||||||
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
||||||
|
config.model_args.use_text_enc_spk_reversal_classifier = False
|
||||||
|
|
||||||
# consistency loss
|
# consistency loss
|
||||||
# config.model_args.use_emotion_encoder_as_loss = True
|
# config.model_args.use_emotion_encoder_as_loss = True
|
||||||
|
|
Loading…
Reference in New Issue