Add emotion classifier loss

Edresson Casanova 2022-05-25 10:05:52 -03:00
parent f50819a5f6
commit 02194367d7
7 changed files with 66 additions and 13 deletions

View File

@@ -115,7 +115,8 @@ class VitsConfig(BaseTTSConfig):
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
consistency_loss_alpha: float = 1.0
text_enc_spk_reversal_loss_alpha: float = 2.0
speaker_classifier_loss_alpha: float = 2.0
emotion_classifier_loss_alpha: float = 4.0
# data loader params
return_wav: bool = True
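For reference, a minimal sketch of how these new weights are set, assuming the standard Coqui TTS config API; the field names come from the diff, everything else is illustrative:

    # Minimal sketch; import path assumed from the standard Coqui TTS layout.
    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.speaker_classifier_loss_alpha = 2.0  # scales the adversarial speaker classifier loss
    config.emotion_classifier_loss_alpha = 4.0  # scales the (non-adversarial) emotion classifier loss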

View File

@@ -20,7 +20,7 @@ class GradientReversalFunction(torch.autograd.Function):
class ReversalClassifier(nn.Module):
"""Adversarial classifier with a gradient reversal layer.
"""Adversarial classifier with an optional gradient reversal layer.
Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/
Args:
@@ -29,19 +29,21 @@ class ReversalClassifier(nn.Module):
hidden_channels (int): Number of hidden channels.
gradient_clipping_bound (float): Maximal value of the gradient which flows from this module. Default: 0.25
scale_factor (float): Scale multiplier of the reversed gradients. Default: 1.0
reversal (bool): If True, reverse the gradients. Default: True
"""
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
super().__init__()
self.reversal = reversal
self._lambda = scale_factor
self._clipping = gradient_clipping_bounds
self._out_channels = out_channels
self._classifier = nn.Sequential(
nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
)
def forward(self, x, labels, x_mask=None):
if self.reversal:
x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
x = self._classifier(x)
loss = self.loss(labels, x, x_mask)
@@ -60,3 +62,4 @@ class ReversalClassifier(nn.Module):
target[~input_mask] = ignore_index
return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
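To make the new reversal flag concrete, here is a standalone sketch of both paths; the GradientReversalFunction body below is the common textbook formulation and only approximates the repository's version:

    # Standalone sketch, assuming PyTorch; negates and clips gradients on
    # the backward pass while passing activations through unchanged.
    import torch
    import torch.nn as nn

    class GradientReversalFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale, clipping):
            ctx.scale, ctx.clipping = scale, clipping
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            grad = (-ctx.scale * grad_output).clamp(-ctx.clipping, ctx.clipping)
            return grad, None, None

    classifier = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 5))
    x = torch.randn(2, 10, 64, requires_grad=True)
    y_adversarial = classifier(GradientReversalFunction.apply(x, 1.0, 0.25))  # reversal=True path
    y_plain = classifier(x)  # reversal=False path: an ordinary auxiliary classifier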

View File

@@ -590,7 +590,8 @@ class VitsGeneratorLoss(nn.Module):
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha
self.consistency_loss_alpha = c.consistency_loss_alpha
self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
@@ -665,7 +666,9 @@ class VitsGeneratorLoss(nn.Module):
gt_cons_emb=None,
syn_cons_emb=None,
loss_prosody_enc_spk_rev_classifier=None,
loss_prosody_enc_emo_classifier=None,
loss_text_enc_spk_rev_classifier=None,
loss_text_enc_emo_classifier=None,
):
"""
Shapes:
@@ -702,14 +705,27 @@ class VitsGeneratorLoss(nn.Module):
return_dict["loss_consistency_enc"] = loss_enc
if loss_prosody_enc_spk_rev_classifier is not None:
loss_prosody_enc_spk_rev_classifier = loss_prosody_enc_spk_rev_classifier * self.speaker_classifier_alpha
loss += loss_prosody_enc_spk_rev_classifier
return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier
if loss_prosody_enc_emo_classifier is not None:
loss_prosody_enc_emo_classifier = loss_prosody_enc_emo_classifier * self.emotion_classifier_alpha
loss += loss_prosody_enc_emo_classifier
return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
if loss_text_enc_spk_rev_classifier is not None:
loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
loss += loss_text_enc_spk_rev_classifier
return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier
if loss_text_enc_emo_classifier is not None:
loss_text_enc_emo_classifier = loss_text_enc_emo_classifier * self.emotion_classifier_alpha
loss += loss_text_enc_emo_classifier
return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl

View File

@@ -541,6 +541,7 @@ class VitsArgs(Coqpit):
emotion_embedding_dim: int = 0
num_emotions: int = 0
use_text_enc_spk_reversal_classifier: bool = False
use_text_enc_emo_classifier: bool = False
# prosody encoder
use_prosody_encoder: bool = False
@@ -548,6 +549,7 @@ class VitsArgs(Coqpit):
prosody_encoder_num_heads: int = 1
prosody_encoder_num_tokens: int = 5
use_prosody_enc_spk_reversal_classifier: bool = False
use_prosody_enc_emo_classifier: bool = False
use_prosody_conditional_flow_module: bool = False
@@ -706,6 +708,13 @@ class Vits(BaseTTS):
out_channels=self.num_speakers,
hidden_channels=256,
)
if self.args.use_prosody_enc_emo_classifier:
self.pros_enc_emotion_classifier = ReversalClassifier(
in_channels=self.args.prosody_embedding_dim,
out_channels=self.num_emotions,
hidden_channels=256,
reversal=False
)
if self.args.use_prosody_conditional_flow_module:
cond_embedding_dim = 0
@@ -732,6 +741,14 @@ class Vits(BaseTTS):
hidden_channels=256,
)
if self.args.use_text_enc_emo_classifier:
self.emo_text_enc_classifier = ReversalClassifier(
in_channels=self.args.hidden_channels,
out_channels=self.num_emotions,
hidden_channels=256,
reversal=False
)
self.waveform_decoder = HifiganGenerator(
self.args.hidden_channels,
1,
@@ -1117,10 +1134,13 @@ class Vits(BaseTTS):
# prosody embedding
pros_emb = None
l_pros_speaker = None
l_pros_emotion = None
if self.args.use_prosody_encoder:
pros_emb = self.prosody_encoder(z).transpose(1, 2)
if self.args.use_prosody_enc_spk_reversal_classifier:
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
if self.args.use_prosody_enc_emo_classifier:
_, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
@@ -1134,6 +1154,11 @@ class Vits(BaseTTS):
if self.args.use_prosody_conditional_flow_module:
m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
# emotion classifier loss to push the text encoder outputs to retain emotion information
l_text_emotion = None
if self.args.use_text_enc_emo_classifier:
_, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=None)
# reversal speaker loss to force the encoder to be speaker identity free
l_text_speaker = None
if self.args.use_text_enc_spk_reversal_classifier:
@@ -1214,7 +1239,9 @@ class Vits(BaseTTS):
"syn_cons_emb": syn_cons_emb,
"slice_ids": slice_ids,
"loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
"loss_prosody_enc_emo_classifier": l_pros_emotion,
"loss_text_enc_spk_rev_classifier": l_text_speaker,
"loss_text_enc_emo_classifier": l_text_emotion,
}
)
return outputs
@@ -1531,7 +1558,9 @@ class Vits(BaseTTS):
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
)
return self.model_outputs_cache, loss_dict
@@ -1737,7 +1766,7 @@ class Vits(BaseTTS):
emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
emotion_embeddings = torch.FloatTensor(emotion_embeddings)
if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
emotion_mapping = self.emotion_manager.embeddings
emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
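For clarity, the id lookup in the last hunk reduces to this tiny example; the mapping contents are invented:

    # Invented data illustrating the emotion-id lookup above.
    emotion_mapping = {"a.wav": {"name": "happy"}, "b.wav": {"name": "sad"}}
    ids = {"happy": 0, "sad": 1}
    audio_files = ["a.wav", "b.wav"]
    emotion_names = [emotion_mapping[w]["name"] for w in audio_files]  # ["happy", "sad"]
    emotion_ids = [ids[name] for name in emotion_names]                # [0, 1]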

View File

@@ -94,7 +94,7 @@ class EmotionManager(EmbeddingManager):
EmotionEncoder: Emotion encoder object.
"""
emotion_manager = None
if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False):
if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
emotion_manager = EmotionManager(
emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager):
)
)
if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
emotion_manager = EmotionManager(
embeddings_file_path=get_from_config_or_model_args_with_default(
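Both conditions go through get_from_config_or_model_args_with_default; a rough sketch of its assumed behavior (this restatement is an assumption, not the repository's code):

    def get_from_config_or_model_args_with_default(config, arg_name, def_val):
        # Assumed lookup order: top-level config key first, then
        # config.model_args, then the supplied default.
        if arg_name in config:
            return config[arg_name]
        if "model_args" in config and arg_name in config.model_args:
            return config.model_args[arg_name]
        return def_val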

View File

@@ -293,9 +293,9 @@ class Synthesizer(object):
# handle emotion
emotion_embedding, emotion_id = None, None
if self.tts_emotions_file or (
if not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
):
)):
if emotion_name and isinstance(emotion_name, str):
if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
getattr(self.tts_config, "model_args", None)
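The reshaped condition is easier to read flattened; a hypothetical helper restating it (attribute names from the diff, the helper itself is not in the repository):

    def should_resolve_emotion(tts_model, tts_emotions_file):
        # Skip emotion-embedding/id resolution entirely when the model
        # carries a prosody encoder; otherwise keep the old behavior.
        has_prosody_encoder = getattr(tts_model, "prosody_encoder", False)
        has_emotion_ids = bool(
            getattr(tts_model, "emotion_manager", None)
            and getattr(tts_model.emotion_manager, "ids", None)
        )
        return not has_prosody_encoder and bool(tts_emotions_file or has_emotion_ids)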

View File

@@ -43,6 +43,10 @@ config.model_args.d_vector_dim = 128
# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64
# activate the emotion classifiers
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.model_args.use_prosody_enc_emo_classifier = True
config.model_args.use_text_enc_emo_classifier = True
config.save_json(config_path)