mirror of https://github.com/coqui-ai/TTS.git
Add emotion classifier loss
parent f50819a5f6
commit 02194367d7
@@ -115,7 +115,8 @@ class VitsConfig(BaseTTSConfig):
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
     consistency_loss_alpha: float = 1.0
     text_enc_spk_reversal_loss_alpha: float = 2.0
     speaker_classifier_loss_alpha: float = 2.0
+    emotion_classifier_loss_alpha: float = 4.0

     # data loader params
     return_wav: bool = True
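For context, the new `emotion_classifier_loss_alpha` sits alongside the existing classifier loss weights and defaults to 4.0. A minimal sketch of overriding it from user code; the import path follows the upstream package layout and should be treated as an assumption:

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.emotion_classifier_loss_alpha = 4.0  # weight for the new emotion classifier terms
    config.speaker_classifier_loss_alpha = 2.0  # weight shared by the speaker classifier terms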
@@ -20,7 +20,7 @@ class GradientReversalFunction(torch.autograd.Function):


 class ReversalClassifier(nn.Module):
-    """Adversarial classifier with a gradient reversal layer.
+    """Adversarial classifier with an optional gradient reversal layer.
     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/

     Args:
@@ -29,20 +29,22 @@ class ReversalClassifier(nn.Module):
         hidden_channels (int): Number of hidden channels.
         gradient_clipping_bound (float): Maximal value of the gradient which flows from this module. Default: 0.25
         scale_factor (float): Scale multiplier of the reversed gradients. Default: 1.0
+        reversal (bool): If True, reverses the gradients. Default: True
     """

-    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
+    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
         super().__init__()
+        self.reversal = reversal
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
             nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
-        self.test = nn.Linear(in_channels, hidden_channels)

     def forward(self, x, labels, x_mask=None):
-        x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
+        if self.reversal:
+            x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)
         loss = self.loss(labels, x, x_mask)
         return x, loss
@@ -60,3 +62,4 @@ class ReversalClassifier(nn.Module):
         target[~input_mask] = ignore_index
+
         return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
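`ReversalClassifier.forward` calls `GradientReversalFunction.apply(x, self._lambda, self._clipping)`, but the function body is folded out of this diff. For orientation, a minimal sketch of a standard gradient-reversal function matching that call signature, modeled on the Tomiinek reference linked in the docstring (an assumption, not this file's actual code):

    import torch

    class GradientReversalFunction(torch.autograd.Function):
        # Identity in the forward pass; the backward pass clips the incoming
        # gradient and flips its sign, scaled by the given factor.
        @staticmethod
        def forward(ctx, x, scale, clipping):
            ctx.scale = scale
            ctx.clipping = clipping
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            grad = -ctx.scale * grad_output.clamp(-ctx.clipping, ctx.clipping)
            return grad, None, None

With `reversal=False`, the new emotion classifiers skip this function entirely: the gradient flows back unchanged, so the encoder is trained to retain emotion information rather than discard it.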
@@ -590,7 +590,8 @@ class VitsGeneratorLoss(nn.Module):
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
         self.consistency_loss_alpha = c.consistency_loss_alpha
         self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
+        self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
         self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha

         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -665,7 +666,9 @@ class VitsGeneratorLoss(nn.Module):
         gt_cons_emb=None,
         syn_cons_emb=None,
         loss_prosody_enc_spk_rev_classifier=None,
+        loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
+        loss_text_enc_emo_classifier=None,
     ):
         """
         Shapes:
@@ -702,14 +705,27 @@ class VitsGeneratorLoss(nn.Module):
             return_dict["loss_consistency_enc"] = loss_enc

         if loss_prosody_enc_spk_rev_classifier is not None:
             loss_prosody_enc_spk_rev_classifier = loss_prosody_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_prosody_enc_spk_rev_classifier
             return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier

+        if loss_prosody_enc_emo_classifier is not None:
+            loss_prosody_enc_emo_classifier = loss_prosody_enc_emo_classifier * self.emotion_classifier_alpha
+            loss += loss_prosody_enc_emo_classifier
+            return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
+
         if loss_text_enc_spk_rev_classifier is not None:
-            loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
+            loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_text_enc_spk_rev_classifier
             return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier

+        if loss_text_enc_emo_classifier is not None:
+            loss_text_enc_emo_classifier = loss_text_enc_emo_classifier * self.emotion_classifier_alpha
+            loss += loss_text_enc_emo_classifier
+            return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
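Every classifier term follows the same pattern: scale the raw loss by its alpha, fold it into the total, and log the weighted value under its own key. Note that the text-encoder speaker-reversal term also switches from `text_enc_spk_reversal_loss_alpha` to the shared `speaker_classifier_alpha`. A self-contained illustration of the accumulation with invented numbers:

    raw = {
        "loss_prosody_enc_emo_classifier": 1.10,  # illustrative raw cross-entropy values
        "loss_text_enc_emo_classifier": 0.80,
    }
    emotion_classifier_alpha = 4.0  # from emotion_classifier_loss_alpha

    loss, return_dict = 0.0, {}
    for name, value in raw.items():
        if value is not None:
            weighted = value * emotion_classifier_alpha
            loss += weighted
            return_dict[name] = weighted

    print(round(loss, 2))  # 7.6 == 1.10 * 4.0 + 0.80 * 4.0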
@@ -541,6 +541,7 @@ class VitsArgs(Coqpit):
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
     use_text_enc_spk_reversal_classifier: bool = False
+    use_text_enc_emo_classifier: bool = False

     # prosody encoder
     use_prosody_encoder: bool = False
@@ -548,6 +549,7 @@ class VitsArgs(Coqpit):
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
     use_prosody_enc_spk_reversal_classifier: bool = False
+    use_prosody_enc_emo_classifier: bool = False

     use_prosody_conditional_flow_module: bool = False
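A minimal sketch of turning the new flags on through `VitsArgs`; the import path is assumed from the upstream layout, and `num_emotions` must match the dataset's label set:

    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        num_emotions=5,                       # size of the emotion label set (assumed)
        use_text_enc_emo_classifier=True,     # emotion classifier on the text encoder output
        use_prosody_encoder=True,
        use_prosody_enc_emo_classifier=True,  # emotion classifier on the prosody embedding
    )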
@@ -706,6 +708,13 @@ class Vits(BaseTTS):
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
+        if self.args.use_prosody_enc_emo_classifier:
+            self.pros_enc_emotion_classifier = ReversalClassifier(
+                in_channels=self.args.prosody_embedding_dim,
+                out_channels=self.num_emotions,
+                hidden_channels=256,
+                reversal=False
+            )

         if self.args.use_prosody_conditional_flow_module:
             cond_embedding_dim = 0
@@ -732,6 +741,14 @@ class Vits(BaseTTS):
                 hidden_channels=256,
             )

+        if self.args.use_text_enc_emo_classifier:
+            self.emo_text_enc_classifier = ReversalClassifier(
+                in_channels=self.args.hidden_channels,
+                out_channels=self.num_emotions,
+                hidden_channels=256,
+                reversal=False
+            )
+
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
@@ -1117,11 +1134,14 @@ class Vits(BaseTTS):
         # prosody embedding
         pros_emb = None
         l_pros_speaker = None
+        l_pros_emotion = None
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
+            if self.args.use_prosody_enc_emo_classifier:
+                _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
             x_lengths,
@@ -1134,6 +1154,11 @@ class Vits(BaseTTS):
         if self.args.use_prosody_conditional_flow_module:
             m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)

+        # emotion classifier loss to force the encoder to keep the emotion information
+        l_text_emotion = None
+        if self.args.use_text_enc_emo_classifier:
+            _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=None)
+
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
         if self.args.use_text_enc_spk_reversal_classifier:
@@ -1214,7 +1239,9 @@ class Vits(BaseTTS):
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
+                "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
+                "loss_text_enc_emo_classifier": l_text_emotion,
             }
         )
         return outputs
@@ -1531,7 +1558,9 @@ class Vits(BaseTTS):
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                 loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
+                loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                 loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
+                loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
             )

         return self.model_outputs_cache, loss_dict
@@ -1737,7 +1766,7 @@ class Vits(BaseTTS):
             emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
             emotion_embeddings = torch.FloatTensor(emotion_embeddings)

-        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
+        if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
             emotion_mapping = self.emotion_manager.embeddings
             emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
             emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
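The widened condition in `format_batch` above means emotion ids are now built whenever any emotion classifier is active, not only when emotion embeddings are used. The lookup itself is a two-step mapping from audio file to emotion name to integer id; a toy, self-contained version with invented data:

    embeddings = {  # stand-in for emotion_manager.embeddings
        "wavs/a.wav": {"name": "happy", "embedding": [0.1, 0.2]},
        "wavs/b.wav": {"name": "sad", "embedding": [0.3, 0.4]},
    }
    ids = {"happy": 0, "sad": 1}  # stand-in for emotion_manager.ids

    audio_files = ["wavs/a.wav", "wavs/b.wav", "wavs/a.wav"]
    emotion_names = [embeddings[w]["name"] for w in audio_files]
    emotion_ids = [ids[name] for name in emotion_names]
    print(emotion_ids)  # [0, 1, 0]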
@@ -94,7 +94,7 @@ class EmotionManager(EmbeddingManager):
         EmotionEncoder: Emotion encoder object.
     """
     emotion_manager = None
-    if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False):
+    if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
         if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
             emotion_manager = EmotionManager(
                 emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager):
             )
         )

-    if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
+    if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
         if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
             emotion_manager = EmotionManager(
                 embeddings_file_path=get_from_config_or_model_args_with_default(
@@ -293,9 +293,9 @@ class Synthesizer(object):

         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if self.tts_emotions_file or (
+        if not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
             getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
-        ):
+        )):
             if emotion_name and isinstance(emotion_name, str):
                 if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
                     getattr(self.tts_config, "model_args", None)
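The reworked `Synthesizer` condition skips user-supplied emotion inputs whenever the model carries a prosody encoder, which infers emotion from reference audio instead. A hedged, standalone paraphrase of the new guard with stand-in objects:

    class FakeModel:  # stand-in for self.tts_model
        prosody_encoder = object()  # attribute present -> external emotion input is skipped

    tts_emotions_file = "emotions.json"  # invented path
    model = FakeModel()

    handle_emotion = not getattr(model, "prosody_encoder", False) and bool(tts_emotions_file)
    print(handle_emotion)  # False: the prosody encoder takes precedence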
@@ -43,6 +43,10 @@ config.model_args.d_vector_dim = 128
 # prosody embedding
 config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
+# activate the emotion classifiers
+config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
+config.model_args.use_prosody_enc_emo_classifier = True
+config.model_args.use_text_enc_emo_classifier = True

 config.save_json(config_path)