mirror of https://github.com/coqui-ai/TTS.git
Add emotion classifier loss
Commit 02194367d7 (parent f50819a5f6)
@@ -115,7 +115,8 @@ class VitsConfig(BaseTTSConfig):
     mel_loss_alpha: float = 45.0
     dur_loss_alpha: float = 1.0
     consistency_loss_alpha: float = 1.0
-    text_enc_spk_reversal_loss_alpha: float = 2.0
+    speaker_classifier_loss_alpha: float = 2.0
+    emotion_classifier_loss_alpha: float = 4.0

     # data loader params
     return_wav: bool = True
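The renamed speaker_classifier_loss_alpha and the new emotion_classifier_loss_alpha are plain scalar weights applied to the classifier losses further down in this commit. A minimal sketch of setting them from user code, assuming this branch's VitsConfig (only the two alpha fields come from the diff; everything else stays at its defaults):

    from TTS.tts.configs.vits_config import VitsConfig

    # Hedged sketch: weight the speaker-classifier losses at 2.0 and the emotion-classifier losses at 4.0.
    config = VitsConfig(
        speaker_classifier_loss_alpha=2.0,
        emotion_classifier_loss_alpha=4.0,
    )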
@@ -20,7 +20,7 @@ class GradientReversalFunction(torch.autograd.Function):


 class ReversalClassifier(nn.Module):
-    """Adversarial classifier with a gradient reversal layer.
+    """Adversarial classifier with an optional gradient reversal layer.

     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/

     Args:
@@ -29,19 +29,21 @@ class ReversalClassifier(nn.Module):
         hidden_channels (int): Number of hidden channels.
         gradient_clipping_bound (float): Maximal value of the gradient which flows from this module. Default: 0.25
         scale_factor (float): Scale multiplier of the reversed gradientts. Default: 1.0
+        reversal (bool): If True reversal the gradients. Default: True
     """

-    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
+    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True):
         super().__init__()
+        self.reversal = reversal
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
             nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
-        self.test = nn.Linear(in_channels, hidden_channels)

     def forward(self, x, labels, x_mask=None):
-        x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
+        if self.reversal:
+            x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)
         loss = self.loss(labels, x, x_mask)
@@ -60,3 +62,4 @@ class ReversalClassifier(nn.Module):
         target[~input_mask] = ignore_index

         return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
+
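For context: a gradient reversal layer leaves the forward pass untouched and flips (and clips) the gradient on the way back, so the module feeding the classifier is trained to discard the classified attribute; with the new reversal=False the same module becomes an ordinary classifier that rewards keeping the attribute. The repository's GradientReversalFunction is not part of this diff, so the following is only a hedged, self-contained sketch of that pattern and of the two usage modes (sizes are illustrative):

    import torch

    class GradientReversalFunction(torch.autograd.Function):
        """Sketch only: identity in the forward pass, negated and clipped gradient in the backward pass."""

        @staticmethod
        def forward(ctx, x, scale, clipping):
            ctx.scale = scale
            ctx.clipping = clipping
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            grad = torch.clamp(-ctx.scale * grad_output, -ctx.clipping, ctx.clipping)
            return grad, None, None

    # Assuming the ReversalClassifier from the diff above is in scope:
    # adversarial mode (default) pushes the attribute (e.g. speaker) OUT of the embedding,
    # plain mode (reversal=False) trains an ordinary classifier that keeps the attribute (e.g. emotion) IN.
    spk_classifier = ReversalClassifier(in_channels=64, out_channels=10, hidden_channels=256)
    emo_classifier = ReversalClassifier(in_channels=64, out_channels=5, hidden_channels=256, reversal=False)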
@@ -590,7 +590,8 @@ class VitsGeneratorLoss(nn.Module):
         self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
         self.consistency_loss_alpha = c.consistency_loss_alpha
-        self.text_enc_spk_reversal_loss_alpha = c.text_enc_spk_reversal_loss_alpha
+        self.emotion_classifier_alpha = c.emotion_classifier_loss_alpha
+        self.speaker_classifier_alpha = c.speaker_classifier_loss_alpha

         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -665,7 +666,9 @@ class VitsGeneratorLoss(nn.Module):
         gt_cons_emb=None,
         syn_cons_emb=None,
         loss_prosody_enc_spk_rev_classifier=None,
+        loss_prosody_enc_emo_classifier=None,
         loss_text_enc_spk_rev_classifier=None,
+        loss_text_enc_emo_classifier=None,
     ):
         """
         Shapes:
@@ -702,14 +705,27 @@ class VitsGeneratorLoss(nn.Module):
             return_dict["loss_consistency_enc"] = loss_enc

         if loss_prosody_enc_spk_rev_classifier is not None:
+            loss_prosody_enc_spk_rev_classifier = loss_prosody_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_prosody_enc_spk_rev_classifier
             return_dict["loss_prosody_enc_spk_rev_classifier"] = loss_prosody_enc_spk_rev_classifier

+        if loss_prosody_enc_emo_classifier is not None:
+            loss_prosody_enc_emo_classifier = loss_prosody_enc_emo_classifier * self.emotion_classifier_alpha
+            loss += loss_prosody_enc_emo_classifier
+            return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier
+
+
         if loss_text_enc_spk_rev_classifier is not None:
-            loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.text_enc_spk_reversal_loss_alpha
+            loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha
             loss += loss_text_enc_spk_rev_classifier
             return_dict["loss_text_enc_spk_rev_classifier"] = loss_text_enc_spk_rev_classifier

+        if loss_text_enc_emo_classifier is not None:
+            loss_text_enc_emo_classifier = loss_text_enc_emo_classifier * self.emotion_classifier_alpha
+            loss += loss_text_enc_emo_classifier
+            return_dict["loss_text_enc_emo_classifier"] = loss_text_enc_emo_classifier
+
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
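Every classifier head above follows the same pattern: if its loss was computed, scale it by the matching alpha, add it to the running total, and store the scaled value for logging. A standalone, hedged sketch of that pattern with the default weights from the config change above (the tensors are illustrative):

    import torch

    def apply_classifier_losses(loss, return_dict, classifier_losses, alphas):
        """Scale each optional classifier loss, fold it into the total, and record it for logging."""
        for name, value in classifier_losses.items():
            if value is None:
                continue  # this classifier is disabled for the current run
            value = value * alphas[name]
            loss = loss + value
            return_dict[name] = value
        return loss

    # Illustrative call; the alpha values mirror the VitsConfig defaults above.
    total = apply_classifier_losses(
        loss=torch.tensor(1.0),
        return_dict={},
        classifier_losses={
            "loss_prosody_enc_emo_classifier": torch.tensor(0.7),
            "loss_text_enc_emo_classifier": None,  # classifier not enabled
        },
        alphas={"loss_prosody_enc_emo_classifier": 4.0, "loss_text_enc_emo_classifier": 4.0},
    )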
@@ -541,6 +541,7 @@ class VitsArgs(Coqpit):
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
     use_text_enc_spk_reversal_classifier: bool = False
+    use_text_enc_emo_classifier: bool = False

     # prosody encoder
     use_prosody_encoder: bool = False
@@ -548,6 +549,7 @@ class VitsArgs(Coqpit):
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
     use_prosody_enc_spk_reversal_classifier: bool = False
+    use_prosody_enc_emo_classifier: bool = False

     use_prosody_conditional_flow_module: bool = False

@@ -706,6 +708,13 @@ class Vits(BaseTTS):
                 out_channels=self.num_speakers,
                 hidden_channels=256,
             )
+        if self.args.use_prosody_enc_emo_classifier:
+            self.pros_enc_emotion_classifier = ReversalClassifier(
+                in_channels=self.args.prosody_embedding_dim,
+                out_channels=self.num_emotions,
+                hidden_channels=256,
+                reversal=False
+            )

         if self.args.use_prosody_conditional_flow_module:
             cond_embedding_dim = 0
@@ -732,6 +741,14 @@ class Vits(BaseTTS):
                 hidden_channels=256,
             )

+        if self.args.use_text_enc_emo_classifier:
+            self.emo_text_enc_classifier = ReversalClassifier(
+                in_channels=self.args.hidden_channels,
+                out_channels=self.num_emotions,
+                hidden_channels=256,
+                reversal=False
+            )
+
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
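Both emotion classifiers are built only when their flags are on, and both pass reversal=False, so they act as plain (non-adversarial) heads on the prosody-encoder output and on the text-encoder output respectively. A hedged sketch of model args that would activate them, assuming this branch's VitsArgs (num_emotions and the dimensions are illustrative):

    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        num_emotions=5,
        use_prosody_encoder=True,
        prosody_embedding_dim=64,
        use_prosody_enc_emo_classifier=True,  # emotion head on the prosody encoder output
        use_text_enc_emo_classifier=True,     # emotion head on the text encoder output
    )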
@@ -1117,10 +1134,13 @@ class Vits(BaseTTS):
         # prosody embedding
         pros_emb = None
         l_pros_speaker = None
+        l_pros_emotion = None
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             if self.args.use_prosody_enc_spk_reversal_classifier:
                 _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
+            if self.args.use_prosody_enc_emo_classifier:
+                _, l_pros_emotion = self.pros_enc_emotion_classifier(pros_emb.transpose(1, 2), eid, x_mask=None)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
@@ -1134,6 +1154,11 @@ class Vits(BaseTTS):
         if self.args.use_prosody_conditional_flow_module:
             m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)

+        # reversal speaker loss to force the encoder to be speaker identity free
+        l_text_emotion = None
+        if self.args.use_text_enc_emo_classifier:
+            _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=None)
+
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
         if self.args.use_text_enc_spk_reversal_classifier:
@@ -1214,7 +1239,9 @@ class Vits(BaseTTS):
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
                 "loss_prosody_enc_spk_rev_classifier": l_pros_speaker,
+                "loss_prosody_enc_emo_classifier": l_pros_emotion,
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
+                "loss_text_enc_emo_classifier": l_text_emotion,
             }
         )
         return outputs
@@ -1531,7 +1558,9 @@ class Vits(BaseTTS):
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                 loss_prosody_enc_spk_rev_classifier=self.model_outputs_cache["loss_prosody_enc_spk_rev_classifier"],
+                loss_prosody_enc_emo_classifier=self.model_outputs_cache["loss_prosody_enc_emo_classifier"],
                 loss_text_enc_spk_rev_classifier=self.model_outputs_cache["loss_text_enc_spk_rev_classifier"],
+                loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"],
             )

         return self.model_outputs_cache, loss_dict
@@ -1737,7 +1766,7 @@ class Vits(BaseTTS):
             emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]]
             emotion_embeddings = torch.FloatTensor(emotion_embeddings)

-        if self.emotion_manager is not None and self.emotion_manager.embeddings and self.args.use_emotion_embedding:
+        if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier):
             emotion_mapping = self.emotion_manager.embeddings
             emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]]
             emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
@@ -94,7 +94,7 @@ class EmotionManager(EmbeddingManager):
             EmotionEncoder: Emotion encoder object.
         """
         emotion_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False):
+        if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
             if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None):
                 emotion_manager = EmotionManager(
                     emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None)
@@ -106,7 +106,7 @@ class EmotionManager(EmbeddingManager):
                     )
                 )

-        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False):
+        if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False):
             if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None):
                 emotion_manager = EmotionManager(
                     embeddings_file_path=get_from_config_or_model_args_with_default(
@@ -293,9 +293,9 @@ class Synthesizer(object):

         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if self.tts_emotions_file or (
+        if not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
             getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
-        ):
+        )):
             if emotion_name and isinstance(emotion_name, str):
                 if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
                     getattr(self.tts_config, "model_args", None)
@@ -43,6 +43,10 @@ config.model_args.d_vector_dim = 128
 # prosody embedding
 config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
+# active classifier
+config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
+config.model_args.use_prosody_enc_emo_classifier = True
+config.model_args.use_text_enc_emo_classifier = True

 config.save_json(config_path)

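The test reuses the speakers.json fixture as external_emotions_embs_file. Judging from how the data loader change above consumes the mapping (entries keyed by audio file path, each carrying a "name" and an "embedding"), a file of roughly the following shape is assumed; the path, emotion name, and vector are illustrative only:

    # Hypothetical content of an external emotions embeddings file, written as a Python dict.
    emotion_mapping = {
        "tests/data/ljspeech/wavs/LJ001-0001.wav": {
            "name": "neutral",         # looked up in EmotionManager.ids to get an integer emotion id
            "embedding": [0.0] * 128,  # used directly when external emotion embeddings are enabled
        },
    }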