mirror of https://github.com/coqui-ai/TTS.git

Implement Angular Prototypical loss

parent 8b9c951da7
commit f0bcc390d2
@@ -23,6 +23,8 @@ class GE2ELoss(nn.Module):
         self.b = nn.Parameter(torch.tensor(init_b))
         self.loss_method = loss_method
 
+        print('Initialised Generalized End-to-End loss')
+
         assert self.loss_method in ["softmax", "contrast"]
 
         if self.loss_method == "softmax":
@@ -119,3 +121,43 @@ class GE2ELoss(nn.Module):
         cos_sim_matrix = self.w * cos_sim_matrix + self.b
         L = self.embed_loss(dvecs, cos_sim_matrix)
         return L.mean()
+
+# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
+class AngleProtoLoss(nn.Module):
+    """
+    Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
+        Accepts an input of size (N, M, D)
+            where N is the number of speakers in the batch,
+            M is the number of utterances per speaker,
+            and D is the dimensionality of the embedding vector
+        Args:
+            - init_w (float): defines the initial value of w
+            - init_b (float): defines the initial value of b
+    """
+    def __init__(self, init_w=10.0, init_b=-5.0):
+        super(AngleProtoLoss, self).__init__()
+        # pylint: disable=E1102
+        self.w = nn.Parameter(torch.tensor(init_w))
+        # pylint: disable=E1102
+        self.b = nn.Parameter(torch.tensor(init_b))
+        self.criterion = torch.nn.CrossEntropyLoss()
+        self.use_cuda = torch.cuda.is_available()
+
+        print('Initialised Angular Prototypical loss')
+
+    def forward(self, x):
+        """
+        Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
+        """
+        out_anchor = torch.mean(x[:, 1:, :], 1)
+        out_positive = x[:, 0, :]
+        num_speakers = out_anchor.size()[0]
+
+        cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
+        self.w.data.clamp_(min=1e-6)  # keep the learnable scale factor positive
+        cos_sim_matrix = cos_sim_matrix * self.w + self.b
+        label = torch.from_numpy(np.asarray(range(0, num_speakers)))
+        if self.use_cuda:
+            label = label.cuda()
+        L = self.criterion(cos_sim_matrix, label)
+        return L
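The expand/transpose expression in forward builds an N x N matrix whose entry (i, j) is the cosine similarity between speaker i's query embedding (out_positive) and speaker j's centroid (out_anchor). As a sanity check only, not part of the commit, the sketch below compares that formulation against a broadcasting one, assuming a PyTorch version where F.cosine_similarity broadcasts; the shapes are made-up values:

import torch
import torch.nn.functional as F

N, D = 4, 64                       # hypothetical: 4 speakers, 64-dim embeddings
out_positive = torch.randn(N, D)   # one query utterance per speaker
out_anchor = torch.randn(N, D)     # centroid of each speaker's remaining utterances

# Formulation used in the commit: expand both to (N, D, N), compare along dim=1 (default).
sim_expand = F.cosine_similarity(
    out_positive.unsqueeze(-1).expand(-1, -1, N),
    out_anchor.unsqueeze(-1).expand(-1, -1, N).transpose(0, 2),
)

# Broadcasting formulation: (N, 1, D) against (1, N, D), reduced along dim=2.
sim_broadcast = F.cosine_similarity(out_positive.unsqueeze(1), out_anchor.unsqueeze(0), dim=2)

print(torch.allclose(sim_expand, sim_broadcast))  # expected: True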
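A minimal usage sketch of the new loss, assuming AngleProtoLoss is in scope; the batch shape below (4 speakers x 5 utterances x 64-dim d-vectors) is made up. Because forward moves the target labels to the GPU whenever CUDA is available, the embeddings and the loss module are placed on that same device:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical batch: 4 speakers, 5 utterances per speaker, 64-dim d-vectors.
num_speakers, num_utters, emb_dim = 4, 5, 64
dvecs = torch.randn(num_speakers, num_utters, emb_dim, device=device, requires_grad=True)

criterion = AngleProtoLoss().to(device)  # assumes the class added in this commit
loss = criterion(dvecs)                  # scalar tensor
loss.backward()                          # gradients flow to dvecs, w, and b
print(loss.item())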