Implement Angular Prototypical loss

This commit is contained in:
Edresson 2020-07-31 00:05:08 -03:00 committed by erogol
parent 8b9c951da7
commit f0bcc390d2
1 changed files with 42 additions and 0 deletions

View File

@ -23,6 +23,8 @@ class GE2ELoss(nn.Module):
self.b = nn.Parameter(torch.tensor(init_b))
self.loss_method = loss_method
print('Initialised Generalized End-to-End loss')
assert self.loss_method in ["softmax", "contrast"]
if self.loss_method == "softmax":
@ -119,3 +121,43 @@ class GE2ELoss(nn.Module):
cos_sim_matrix = self.w * cos_sim_matrix + self.b
L = self.embed_loss(dvecs, cos_sim_matrix)
return L.mean()
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
class AngleProtoLoss(nn.Module):
"""
Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
Accepts an input of size (N, M, D)
where N is the number of speakers in the batch,
M is the number of utterances per speaker,
and D is the dimensionality of the embedding vector
Args:
- init_w (float): defines the initial value of w
- init_b (float): definies the initial value of b
"""
def __init__(self, init_w=10.0, init_b=-5.0):
super(AngleProtoLoss, self).__init__()
# pylint: disable=E1102
self.w = nn.Parameter(torch.tensor(init_w))
# pylint: disable=E1102
self.b = nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()
self.use_cuda = torch.cuda.is_available()
print('Initialised Angular Prototypical loss')
def forward(self, x):
"""
Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
"""
out_anchor = torch.mean(x[:,1:,:],1)
out_positive = x[:,0,:]
num_speakers = out_anchor.size()[0]
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,num_speakers),out_anchor.unsqueeze(-1).expand(-1,-1,num_speakers).transpose(0,2))
torch.clamp(self.w, 1e-6)
cos_sim_matrix = cos_sim_matrix * self.w + self.b
label = torch.from_numpy(np.asarray(range(0,num_speakers)))
if self.use_cuda:
label = label.cuda()
L = self.criterion(cos_sim_matrix, label)
return L