use torch for AngleProtoLoss

This commit is contained in:
Eren Gölge 2021-03-03 15:37:38 +01:00 committed by Eren Gölge
parent 2b3e12ea49
commit 892c3c3623
1 changed files with 1 additions and 1 deletions

View File

@ -155,6 +155,6 @@ class AngleProtoLoss(nn.Module):
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
torch.clamp(self.w, 1e-6)
cos_sim_matrix = cos_sim_matrix * self.w + self.b
label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)
label = torch.arange(num_speakers).to(cos_sim_matrix.device)
L = self.criterion(cos_sim_matrix, label)
return L