mirror of https://github.com/coqui-ai/TTS.git
fix Lint check errors
This commit is contained in:
parent
f37159c135
commit
f3b8ef4272
|
@ -124,14 +124,14 @@ class GE2ELoss(nn.Module):
|
|||
|
||||
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
|
||||
class AngleProtoLoss(nn.Module):
|
||||
"""
|
||||
"""
|
||||
Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
|
||||
Accepts an input of size (N, M, D)
|
||||
where N is the number of speakers in the batch,
|
||||
M is the number of utterances per speaker,
|
||||
and D is the dimensionality of the embedding vector
|
||||
Args:
|
||||
- init_w (float): defines the initial value of w
|
||||
- init_w (float): defines the initial value of w
|
||||
- init_b (float): definies the initial value of b
|
||||
"""
|
||||
def __init__(self, init_w=10.0, init_b=-5.0):
|
||||
|
@ -148,13 +148,13 @@ class AngleProtoLoss(nn.Module):
|
|||
"""
|
||||
Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
|
||||
"""
|
||||
out_anchor = torch.mean(x[:,1:,:],1)
|
||||
out_positive = x[:,0,:]
|
||||
out_anchor = torch.mean(x[:, 1:, :], 1)
|
||||
out_positive = x[:, 0, :]
|
||||
num_speakers = out_anchor.size()[0]
|
||||
|
||||
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,num_speakers),out_anchor.unsqueeze(-1).expand(-1,-1,num_speakers).transpose(0,2))
|
||||
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),out_anchor.unsqueeze(-1).expand(-1, -1,num_speakers).transpose(0, 2))
|
||||
torch.clamp(self.w, 1e-6)
|
||||
cos_sim_matrix = cos_sim_matrix * self.w + self.b
|
||||
label = torch.from_numpy(np.asarray(range(0,num_speakers))).to(cos_sim_matrix.device)
|
||||
label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)
|
||||
L = self.criterion(cos_sim_matrix, label)
|
||||
return L
|
Loading…
Reference in New Issue