diff --git a/mozilla_voice_tts/speaker_encoder/losses.py b/mozilla_voice_tts/speaker_encoder/losses.py index 9065ccfd..f4687949 100644 --- a/mozilla_voice_tts/speaker_encoder/losses.py +++ b/mozilla_voice_tts/speaker_encoder/losses.py @@ -124,14 +124,14 @@ class GE2ELoss(nn.Module): # adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py class AngleProtoLoss(nn.Module): - """ + """ Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 Accepts an input of size (N, M, D) where N is the number of speakers in the batch, M is the number of utterances per speaker, and D is the dimensionality of the embedding vector Args: - - init_w (float): defines the initial value of w + - init_w (float): defines the initial value of w - init_b (float): definies the initial value of b """ def __init__(self, init_w=10.0, init_b=-5.0): @@ -148,13 +148,13 @@ class AngleProtoLoss(nn.Module): """ Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) """ - out_anchor = torch.mean(x[:,1:,:],1) - out_positive = x[:,0,:] + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] num_speakers = out_anchor.size()[0] - cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,num_speakers),out_anchor.unsqueeze(-1).expand(-1,-1,num_speakers).transpose(0,2)) + cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),out_anchor.unsqueeze(-1).expand(-1, -1,num_speakers).transpose(0, 2)) torch.clamp(self.w, 1e-6) cos_sim_matrix = cos_sim_matrix * self.w + self.b - label = torch.from_numpy(np.asarray(range(0,num_speakers))).to(cos_sim_matrix.device) + label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device) L = self.criterion(cos_sim_matrix, label) return L \ No newline at end of file