Add reversal classifier loss

Edresson Casanova 2022-04-18 21:09:59 -03:00
parent 004862a79b
commit d49c6ab72f
4 changed files with 79 additions and 2 deletions

View File

@@ -0,0 +1,61 @@
import torch
from torch import nn


class GradientReversalFunction(torch.autograd.Function):
    """Reverses the gradient without any further input modification.

    Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""

    @staticmethod
    def forward(ctx, x, l, c):
        ctx.l = l
        ctx.c = c
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # clamp the incoming gradient, then negate and scale it on the way back
        grad_output = grad_output.clamp(-ctx.c, ctx.c)
        return ctx.l * grad_output.neg(), None, None


class ReversalClassifier(nn.Module):
    """Adversarial classifier with a gradient reversal layer.

    Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/

    Args:
        in_channels (int): Number of input tensor channels.
        out_channels (int): Number of output tensor channels (number of classes).
        hidden_channels (int): Number of hidden channels.
        gradient_clipping_bounds (float): Maximal value of the gradient which flows from this module. Default: 0.25
        scale_factor (float): Scale multiplier of the reversed gradients. Default: 1.0
    """

    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
        super(ReversalClassifier, self).__init__()
        self._lambda = scale_factor
        self._clipping = gradient_clipping_bounds
        self._out_channels = out_channels
        self._classifier = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x, labels, x_mask=None):
        x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
        x = self._classifier(x)
        loss = self.loss(labels, x, x_mask)
        return x, loss

    @staticmethod
    def loss(labels, predictions, x_mask):
        ignore_index = -100
        # if no mask is given, treat every frame of every utterance as valid
        if x_mask is None:
            x_mask = torch.Tensor([predictions.size(1)]).repeat(predictions.size(0)).int().to(predictions.device)
        ml = torch.max(x_mask)
        input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
        # broadcast the per-utterance label across time and ignore padded frames
        target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
        target[~input_mask] = ignore_index
        return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
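
Read in isolation, the new module can be smoke-tested on dummy tensors. The following sketch is illustrative only: batch size, sequence length, embedding width, and speaker count are made up rather than taken from the commit, and the import path is the one added to the model file below.

import torch
from TTS.tts.layers.generic.classifier import ReversalClassifier

# hypothetical shapes: 4 utterances, 100 frames, 256-dim embeddings, 10 speakers
x = torch.randn(4, 100, 256, requires_grad=True)  # [B, T, C] input features
speaker_ids = torch.randint(0, 10, (4,))          # one target speaker id per utterance

clf = ReversalClassifier(in_channels=256, out_channels=10, hidden_channels=256)
logits, loss = clf(x, speaker_ids)                # logits: [B, T, 10], loss: scalar

loss.backward()
# GradientReversalFunction clamps the classification gradient to [-0.25, 0.25] and
# negates it before it reaches `x`, so whatever produced `x` is pushed to become
# less speaker-discriminative while the classifier itself still learns to predict speakers.
print(logits.shape, loss.item(), x.grad.shape)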

View File

@@ -662,6 +662,7 @@ class VitsGeneratorLoss(nn.Module):
        use_encoder_consistency_loss=False,
        gt_cons_emb=None,
        syn_cons_emb=None,
        loss_spk_reversal_classifier=None,
    ):
        """
        Shapes:
@@ -696,6 +697,11 @@ class VitsGeneratorLoss(nn.Module):
            loss_enc = self.cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * self.consistency_loss_alpha
            loss = loss + loss_enc
            return_dict["loss_consistency_enc"] = loss_enc

        if loss_spk_reversal_classifier is not None:
            loss += loss_spk_reversal_classifier
            return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier

        # pass losses to the dict
        return_dict["loss_gen"] = loss_gen
        return_dict["loss_kl"] = loss_kl

View File

@@ -34,6 +34,7 @@ from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.generic.classifier import ReversalClassifier

##############################
# IO / Feature extraction
@@ -688,6 +689,11 @@ class Vits(BaseTTS):
                num_style_tokens=5,
                gst_embedding_dim=self.args.prosody_embedding_dim,
            )
            self.speaker_reversal_classifier = ReversalClassifier(
                in_channels=self.args.prosody_embedding_dim,
                out_channels=self.num_speakers,
                hidden_channels=256,
            )

        self.waveform_decoder = HifiganGenerator(
            self.args.hidden_channels,
@@ -1081,8 +1087,10 @@
        # prosody embedding
        pros_emb = None
        l_pros_speaker = None
        if self.args.use_prosody_encoder:
            pros_emb = self.prosody_encoder(z).transpose(1, 2)
            _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
@@ -1160,6 +1168,7 @@
                "gt_cons_emb": gt_cons_emb,
                "syn_cons_emb": syn_cons_emb,
                "slice_ids": slice_ids,
                "loss_spk_reversal_classifier": l_pros_speaker,
            }
        )
        return outputs
@@ -1445,6 +1454,7 @@
                or self.args.use_emotion_encoder_as_loss,
                gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
                loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
            )
        return self.model_outputs_cache, loss_dict
@@ -1612,7 +1622,7 @@
        emotion_ids = None

        # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
+        if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
            speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]

        if speaker_ids is not None:

View File

@@ -95,7 +95,7 @@ class SpeakerManager(EmbeddingManager):
            SpeakerEncoder: Speaker encoder object.
        """
        speaker_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False):
+        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
            if samples:
                speaker_manager = SpeakerManager(data_items=samples)
            if get_from_config_or_model_args_with_default(config, "speaker_file", None):
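
Taken together, the last two hunks mean the speaker-adversarial path no longer depends on classic speaker embeddings being enabled: a prosody-encoder-only model still gets a speaker manager and numerical speaker ids in the batch. Below is a rough config sketch, assuming this fork's model args expose the fields referenced in the diff (use_prosody_encoder, prosody_embedding_dim); the surrounding names (VitsConfig, VitsArgs, model_args, use_speaker_embedding) follow upstream Coqui TTS conventions.

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_prosody_encoder=True,     # builds the prosody encoder and the ReversalClassifier
    prosody_embedding_dim=256,    # in_channels of the reversal classifier
    use_speaker_embedding=False,  # no speaker embedding needed; the prosody path alone
)                                 # now triggers the speaker-id handling gated above
config = VitsConfig(model_args=model_args)

# With the SpeakerManager.init_from_config change, passing this config together with
# dataset samples creates a speaker manager even though use_speaker_embedding is False,
# so the reversal classifier has speaker ids to use as adversarial targets.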