mirror of https://github.com/coqui-ai/TTS.git
use torch for AngleProtoLoss
This commit is contained in:
parent
2b3e12ea49
commit
892c3c3623
|
@ -155,6 +155,6 @@ class AngleProtoLoss(nn.Module):
|
|||
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
|
||||
torch.clamp(self.w, 1e-6)
|
||||
cos_sim_matrix = cos_sim_matrix * self.w + self.b
|
||||
label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)
|
||||
label = torch.arange(num_speakers).to(cos_sim_matrix.device)
|
||||
L = self.criterion(cos_sim_matrix, label)
|
||||
return L
|
||||
|
|
Loading…
Reference in New Issue