diff --git a/mozilla_voice_tts/speaker_encoder/losses.py b/mozilla_voice_tts/speaker_encoder/losses.py index f4687949..750648e5 100644 --- a/mozilla_voice_tts/speaker_encoder/losses.py +++ b/mozilla_voice_tts/speaker_encoder/losses.py @@ -152,7 +152,7 @@ class AngleProtoLoss(nn.Module): out_positive = x[:, 0, :] num_speakers = out_anchor.size()[0] - cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),out_anchor.unsqueeze(-1).expand(-1, -1,num_speakers).transpose(0, 2)) + cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2)) torch.clamp(self.w, 1e-6) cos_sim_matrix = cos_sim_matrix * self.w + self.b label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)