diff --git a/mozilla_voice_tts/bin/train_encoder.py b/mozilla_voice_tts/bin/train_encoder.py index d612ac6e..c7c2e647 100644 --- a/mozilla_voice_tts/bin/train_encoder.py +++ b/mozilla_voice_tts/bin/train_encoder.py @@ -100,7 +100,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): if global_step % c.steps_plot_stats == 0: # Plot Training Epoch Stats train_stats = { - "GE2Eloss": avg_loss, + "loss": avg_loss, "lr": current_lr, "grad_norm": grad_norm, "step_time": step_time @@ -140,7 +140,13 @@ def main(args): # pylint: disable=redefined-outer-name lstm_dim=384, num_lstm_layers=3) optimizer = RAdam(model.parameters(), lr=c.lr) - criterion = GE2ELoss(loss_method='softmax') + + if c.loss == "ge2e": + criterion = GE2ELoss(loss_method='softmax') + elif c.loss == "angleproto": + criterion = AngleProtoLoss() + else: + raise Exception("The %s not is a loss supported" %c.loss) if args.restore_path: checkpoint = torch.load(args.restore_path) @@ -186,7 +192,6 @@ def main(args): # pylint: disable=redefined-outer-name _, global_step = train(model, criterion, optimizer, scheduler, ap, global_step) - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( diff --git a/mozilla_voice_tts/speaker_encoder/config.json b/mozilla_voice_tts/speaker_encoder/config.json index 0d0f8f68..5f72135f 100644 --- a/mozilla_voice_tts/speaker_encoder/config.json +++ b/mozilla_voice_tts/speaker_encoder/config.json @@ -21,6 +21,7 @@ "do_trim_silence": false // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) }, "reinit_layers": [], + "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) "grad_clip": 3.0, // upper limit for gradients for clipping. "epochs": 1000, // total number of epochs to train. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. 
# NOTE(review): this span of the collapsed patch also carried the hunks for
# generic_utils.py. They only rename a checkpoint dictionary key and are kept
# verbatim here (original indentation was lost in the source) so that nothing
# from the patch is dropped:
#
#   diff --git a/mozilla_voice_tts/speaker_encoder/generic_utils.py b/mozilla_voice_tts/speaker_encoder/generic_utils.py
#   index f649ceb9..bc72c91c 100644
#   --- a/mozilla_voice_tts/speaker_encoder/generic_utils.py
#   +++ b/mozilla_voice_tts/speaker_encoder/generic_utils.py
#   @@ -15,7 +15,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path,
#    'optimizer': optimizer.state_dict() if optimizer is not None else None,
#    'step': current_step,
#    'epoch': epoch,
#   - 'GE2Eloss': model_loss,
#   + 'loss': model_loss,
#    'date': datetime.date.today().strftime("%B %d, %Y"),
#    }
#    torch.save(state, checkpoint_path)
#   @@ -29,7 +29,7 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
#    'model': new_state_dict,
#    'optimizer': optimizer.state_dict(),
#    'step': current_step,
#   - 'GE2Eloss': model_loss,
#   + 'loss': model_loss,
#    'date': datetime.date.today().strftime("%B %d, %Y"),
#    }
#    best_loss = model_loss
#
# File added by the patch: mozilla_voice_tts/speaker_encoder/losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # pylint: disable=unused-import


def _check_utterances_per_speaker(embeddings):
    """Raise ``ValueError`` unless ``embeddings`` is 3-D with at least two utterances per speaker."""
    if embeddings.dim() != 3 or embeddings.size(1) < 2:
        raise ValueError(
            "Expected embeddings of shape (num_speakers, num_utts_per_speaker, dvec_feats) "
            "with at least 2 utterances per speaker, got shape %s" % (tuple(embeddings.shape),)
        )


# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
    """
    Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
    Accepts an input of size (N, M, D)
        where N is the number of speakers in the batch,
        M is the number of utterances per speaker,
        and D is the dimensionality of the embedding vector (e.g. d-vector)
    Args:
        - init_w (float): defines the initial value of w in Equation (5) of [1]
        - init_b (float): defines the initial value of b in Equation (5) of [1]
        - loss_method (str): "softmax" or "contrast"; selects which per-embedding
          loss (``embed_loss_softmax`` / ``embed_loss_contrast``) ``forward`` uses
    """
    def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
        super().__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.loss_method = loss_method

        print('Initialised Generalized End-to-End loss')

        assert self.loss_method in ["softmax", "contrast"]

        if self.loss_method == "softmax":
            self.embed_loss = self.embed_loss_softmax
        if self.loss_method == "contrast":
            self.embed_loss = self.embed_loss_contrast

    # pylint: disable=R0201
    def calc_new_centroids(self, dvecs, centroids, spkr, utt):
        """
        Calculates the new centroids excluding the reference utterance

        Returns a copy of ``centroids`` (N, D) in which row ``spkr`` is replaced by
        the mean of speaker ``spkr``'s utterances with utterance ``utt`` left out.
        """
        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1:])).mean(0)
        # Splice the leave-one-out centroid into place with a single cat instead
        # of rebuilding the matrix row by row in Python.
        return torch.cat((centroids[:spkr], excl.unsqueeze(0), centroids[spkr + 1:]))

    def calc_cosine_sim(self, dvecs, centroids):
        """
        Make the cosine similarity matrix with dims (N,M,N)

        Entry [j, i, k] is the cosine similarity between utterance i of speaker j
        and centroid k, where the speaker's own centroid excludes that utterance.
        Similarities are clamped from below at 1e-6.
        """
        cos_sim_matrix = []
        for spkr_idx, speaker in enumerate(dvecs):
            cs_row = []
            for utt_idx, utterance in enumerate(speaker):
                new_centroids = self.calc_new_centroids(
                    dvecs, centroids, spkr_idx, utt_idx
                )
                # vector based cosine similarity for speed
                dots = torch.mm(utterance.unsqueeze(0), new_centroids.transpose(0, 1))
                norms = torch.norm(utterance) * torch.norm(new_centroids, dim=1)
                cs_row.append(torch.clamp(dots / norms, min=1e-6))
            cos_sim_matrix.append(torch.cat(cs_row, dim=0))
        return torch.stack(cos_sim_matrix)

    # pylint: disable=R0201
    def embed_loss_softmax(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by taking softmax

        Returns a tensor of shape (N, M) whose entry [j, i] is
        ``-log_softmax(cos_sim_matrix[j, i])[j]``.
        """
        num_speakers = cos_sim_matrix.shape[0]
        spkr = torch.arange(num_speakers, device=cos_sim_matrix.device)
        # Indexing dims 0 and 2 with the same index picks, for every utterance,
        # the log-probability of its own speaker; the result has shape (N, M).
        return -F.log_softmax(cos_sim_matrix, dim=2)[spkr, :, spkr]

    # pylint: disable=R0201
    def embed_loss_contrast(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid

        Returns a tensor of shape (N, M).
        """
        N, M, _ = dvecs.shape
        rows = []
        for j in range(N):
            row = []
            for i in range(M):
                centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
                excl_centroids_sigmoids = torch.cat(
                    (centroids_sigmoids[:j], centroids_sigmoids[j + 1:])
                )
                row.append(
                    1.0 - centroids_sigmoids[j] + torch.max(excl_centroids_sigmoids)
                )
            rows.append(torch.stack(row))
        return torch.stack(rows)

    def forward(self, dvecs):
        """
        Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)

        Returns the mean of the per-embedding losses as a scalar tensor.
        Raises ValueError if the input is not 3-D or has fewer than 2 utterances
        per speaker (the leave-one-out centroid would be the mean of nothing).
        """
        _check_utterances_per_speaker(dvecs)
        # Keep the scale positive. The previous `torch.clamp(self.w, 1e-6)`
        # discarded its result and therefore never constrained w.
        with torch.no_grad():
            self.w.clamp_(min=1e-6)
        centroids = torch.mean(dvecs, 1)
        cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
        cos_sim_matrix = self.w * cos_sim_matrix + self.b
        L = self.embed_loss(dvecs, cos_sim_matrix)
        return L.mean()


# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
class AngleProtoLoss(nn.Module):
    """
    Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
    Accepts an input of size (N, M, D)
        where N is the number of speakers in the batch,
        M is the number of utterances per speaker,
        and D is the dimensionality of the embedding vector
    Args:
        - init_w (float): defines the initial value of w
        - init_b (float): defines the initial value of b
    """
    def __init__(self, init_w=10.0, init_b=-5.0):
        super().__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()

        print('Initialised Angular Prototypical loss')

    def forward(self, x):
        """
        Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)

        The first utterance of every speaker is compared against the mean of that
        speaker's remaining utterances; the scaled similarity matrix is scored with
        cross-entropy using the speaker index as the target.
        Raises ValueError if the input is not 3-D or has fewer than 2 utterances
        per speaker.
        """
        _check_utterances_per_speaker(x)
        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        num_speakers = out_anchor.size()[0]

        # (N, D, N) vs (N, D, N): entry [i, :, j] pairs positive i with anchor j.
        cos_sim_matrix = F.cosine_similarity(
            out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
            out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2),
        )
        # Keep the scale positive. The previous `torch.clamp(self.w, 1e-6)`
        # discarded its result and therefore never constrained w.
        with torch.no_grad():
            self.w.clamp_(min=1e-6)
        cos_sim_matrix = cos_sim_matrix * self.w + self.b
        # Build the targets directly on the right device (no numpy round-trip).
        label = torch.arange(num_speakers, device=cos_sim_matrix.device)
        L = self.criterion(cos_sim_matrix, label)
        return L