diff --git a/TTS/tts/layers/generic/classifier.py b/TTS/tts/layers/generic/classifier.py
new file mode 100644
index 00000000..b0136625
--- /dev/null
+++ b/TTS/tts/layers/generic/classifier.py
@@ -0,0 +1,61 @@
+import torch
+from torch import nn
+
+class GradientReversalFunction(torch.autograd.Function):
+    """Reverses the gradient without any further input modification.
+    Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
+
+    @staticmethod
+    def forward(ctx, x, l, c):
+        ctx.l = l
+        ctx.c = c
+        return x.view_as(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = grad_output.clamp(-ctx.c, ctx.c)
+        return ctx.l * grad_output.neg(), None, None
+
+
+class ReversalClassifier(nn.Module):
+    """Adversarial classifier with a gradient reversal layer.
+    Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/
+
+    Args:
+        in_channels (int): Number of input tensor channels.
+        out_channels (int): Number of output tensor channels (number of classes).
+        hidden_channels (int): Number of hidden channels.
+        gradient_clipping_bounds (float): Maximal value of the gradient which flows from this module. Default: 0.25
+        scale_factor (float): Scale multiplier of the reversed gradients. Default: 1.0
+    """
+
+    def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
+        super(ReversalClassifier, self).__init__()
+        self._lambda = scale_factor
+        self._clipping = gradient_clipping_bounds
+        self._out_channels = out_channels
+        self._classifier = nn.Sequential(
+            nn.Linear(in_channels, hidden_channels),
+            nn.ReLU(),
+            nn.Linear(hidden_channels, out_channels)
+        )
+
+    def forward(self, x, labels, x_mask=None):
+        x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
+        x = self._classifier(x)
+        loss = self.loss(labels, x, x_mask)
+        return x, loss
+
+    @staticmethod
+    def loss(labels, predictions, x_mask):
+        ignore_index = -100
+        if x_mask is None:
+            x_mask = torch.Tensor([predictions.size(1)]).repeat(predictions.size(0)).int().to(predictions.device)
+
+        ml = torch.max(x_mask)
+        input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
+
+        target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
+        target[~input_mask] = ignore_index
+
+        return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 87418473..1d47745c 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -662,6 +662,7 @@ class VitsGeneratorLoss(nn.Module):
         use_encoder_consistency_loss=False,
         gt_cons_emb=None,
         syn_cons_emb=None,
+        loss_spk_reversal_classifier=None,
     ):
         """
         Shapes:
@@ -696,6 +697,11 @@ class VitsGeneratorLoss(nn.Module):
             loss_enc = self.cosine_similarity_loss(gt_cons_emb, syn_cons_emb) * self.consistency_loss_alpha
             loss = loss + loss_enc
             return_dict["loss_consistency_enc"] = loss_enc
+
+        if loss_spk_reversal_classifier is not None:
+            loss += loss_spk_reversal_classifier
+            return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 78750f4c..510f4091 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -34,6 +34,7 @@
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.generic.classifier import ReversalClassifier
 
 ##############################
 # IO / Feature extraction
@@ -688,6 +689,11 @@ class Vits(BaseTTS):
                 num_style_tokens=5,
                 gst_embedding_dim=self.args.prosody_embedding_dim,
             )
+            self.speaker_reversal_classifier = ReversalClassifier(
+                in_channels=self.args.prosody_embedding_dim,
+                out_channels=self.num_speakers,
+                hidden_channels=256,
+            )
 
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
@@ -1081,8 +1087,10 @@ class Vits(BaseTTS):
 
         # prosody embedding
         pros_emb = None
+        l_pros_speaker = None
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
+            _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
 
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
 
@@ -1160,6 +1168,7 @@ class Vits(BaseTTS):
                 "gt_cons_emb": gt_cons_emb,
                 "syn_cons_emb": syn_cons_emb,
                 "slice_ids": slice_ids,
+                "loss_spk_reversal_classifier": l_pros_speaker,
             }
         )
         return outputs
@@ -1445,6 +1454,7 @@ class Vits(BaseTTS):
                     or self.args.use_emotion_encoder_as_loss,
                     gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                     syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
+                    loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
                 )
             return self.model_outputs_cache, loss_dict
 
@@ -1612,7 +1622,7 @@ class Vits(BaseTTS):
         emotion_ids = None
 
         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
+        if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
             speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
 
         if speaker_ids is not None:
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 77b61a8d..c16f71d0 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -95,7 +95,7 @@ class SpeakerManager(EmbeddingManager):
             SpeakerEncoder: Speaker encoder object.
         """
         speaker_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False):
+        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
            if samples:
                speaker_manager = SpeakerManager(data_items=samples)
            if get_from_config_or_model_args_with_default(config, "speaker_file", None):
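
A minimal sketch (not part of the patch) of how the new ReversalClassifier is exercised by the prosody-encoder path above: the prosody embedding enters as [batch, time, channels], the classifier predicts per-frame speaker logits, and GradientReversalFunction negates and clamps the gradient flowing back, so minimizing the returned cross-entropy pushes the upstream encoder to discard speaker information. The batch/time sizes and channel count below are illustrative, not taken from any config.

import torch

from TTS.tts.layers.generic.classifier import ReversalClassifier

B, T, C, n_speakers = 4, 50, 192, 10                    # hypothetical sizes for illustration
pros_emb = torch.randn(B, T, C, requires_grad=True)     # stand-in for the prosody embedding
speaker_ids = torch.randint(0, n_speakers, (B,))        # one speaker label per utterance

clf = ReversalClassifier(in_channels=C, out_channels=n_speakers, hidden_channels=256)
logits, loss = clf(pros_emb, speaker_ids)               # x_mask=None -> all frames are counted
loss.backward()

# logits: [B, T, n_speakers]; pros_emb.grad holds the reversed (negated, clamped to
# +/- gradient_clipping_bounds) gradient, which is what makes the classifier adversarial
# with respect to the encoder that produced pros_emb.
print(logits.shape, loss.item(), pros_emb.grad.abs().max().item())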