mirror of https://github.com/coqui-ai/TTS.git
add suport for AngleProto loss
This commit is contained in:
parent
f0bcc390d2
commit
bc09ca8111
|
@ -100,7 +100,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
|
||||||
if global_step % c.steps_plot_stats == 0:
|
if global_step % c.steps_plot_stats == 0:
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
train_stats = {
|
train_stats = {
|
||||||
"GE2Eloss": avg_loss,
|
"loss": avg_loss,
|
||||||
"lr": current_lr,
|
"lr": current_lr,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
"step_time": step_time
|
"step_time": step_time
|
||||||
|
@ -140,7 +140,13 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
lstm_dim=384,
|
lstm_dim=384,
|
||||||
num_lstm_layers=3)
|
num_lstm_layers=3)
|
||||||
optimizer = RAdam(model.parameters(), lr=c.lr)
|
optimizer = RAdam(model.parameters(), lr=c.lr)
|
||||||
|
|
||||||
|
if c.loss == "ge2e":
|
||||||
criterion = GE2ELoss(loss_method='softmax')
|
criterion = GE2ELoss(loss_method='softmax')
|
||||||
|
elif c.loss == "angleproto":
|
||||||
|
criterion = AngleProtoLoss()
|
||||||
|
else:
|
||||||
|
raise Exception("The %s not is a loss supported" %c.loss)
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
|
@ -186,7 +192,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
_, global_step = train(model, criterion, optimizer, scheduler, ap,
|
_, global_step = train(model, criterion, optimizer, scheduler, ap,
|
||||||
global_step)
|
global_step)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
"do_trim_silence": false // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
"do_trim_silence": false // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||||
},
|
},
|
||||||
"reinit_layers": [],
|
"reinit_layers": [],
|
||||||
|
"loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA)
|
||||||
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
||||||
"epochs": 1000, // total number of epochs to train.
|
"epochs": 1000, // total number of epochs to train.
|
||||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||||
|
|
|
@ -15,7 +15,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path,
|
||||||
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
||||||
'step': current_step,
|
'step': current_step,
|
||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
'GE2Eloss': model_loss,
|
'loss': model_loss,
|
||||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||||
}
|
}
|
||||||
torch.save(state, checkpoint_path)
|
torch.save(state, checkpoint_path)
|
||||||
|
@ -29,7 +29,7 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||||
'model': new_state_dict,
|
'model': new_state_dict,
|
||||||
'optimizer': optimizer.state_dict(),
|
'optimizer': optimizer.state_dict(),
|
||||||
'step': current_step,
|
'step': current_step,
|
||||||
'GE2Eloss': model_loss,
|
'loss': model_loss,
|
||||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||||
}
|
}
|
||||||
best_loss = model_loss
|
best_loss = model_loss
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# adapted from https://github.com/cvqluu/GE2E-Loss
|
||||||
|
class GE2ELoss(nn.Module):
|
||||||
|
def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
|
||||||
|
"""
|
||||||
|
Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
|
||||||
|
Accepts an input of size (N, M, D)
|
||||||
|
where N is the number of speakers in the batch,
|
||||||
|
M is the number of utterances per speaker,
|
||||||
|
and D is the dimensionality of the embedding vector (e.g. d-vector)
|
||||||
|
Args:
|
||||||
|
- init_w (float): defines the initial value of w in Equation (5) of [1]
|
||||||
|
- init_b (float): definies the initial value of b in Equation (5) of [1]
|
||||||
|
"""
|
||||||
|
super(GE2ELoss, self).__init__()
|
||||||
|
# pylint: disable=E1102
|
||||||
|
self.w = nn.Parameter(torch.tensor(init_w))
|
||||||
|
# pylint: disable=E1102
|
||||||
|
self.b = nn.Parameter(torch.tensor(init_b))
|
||||||
|
self.loss_method = loss_method
|
||||||
|
|
||||||
|
print('Initialised Generalized End-to-End loss')
|
||||||
|
|
||||||
|
assert self.loss_method in ["softmax", "contrast"]
|
||||||
|
|
||||||
|
if self.loss_method == "softmax":
|
||||||
|
self.embed_loss = self.embed_loss_softmax
|
||||||
|
if self.loss_method == "contrast":
|
||||||
|
self.embed_loss = self.embed_loss_contrast
|
||||||
|
|
||||||
|
# pylint: disable=R0201
|
||||||
|
def calc_new_centroids(self, dvecs, centroids, spkr, utt):
|
||||||
|
"""
|
||||||
|
Calculates the new centroids excluding the reference utterance
|
||||||
|
"""
|
||||||
|
excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :]))
|
||||||
|
excl = torch.mean(excl, 0)
|
||||||
|
new_centroids = []
|
||||||
|
for i, centroid in enumerate(centroids):
|
||||||
|
if i == spkr:
|
||||||
|
new_centroids.append(excl)
|
||||||
|
else:
|
||||||
|
new_centroids.append(centroid)
|
||||||
|
return torch.stack(new_centroids)
|
||||||
|
|
||||||
|
def calc_cosine_sim(self, dvecs, centroids):
|
||||||
|
"""
|
||||||
|
Make the cosine similarity matrix with dims (N,M,N)
|
||||||
|
"""
|
||||||
|
cos_sim_matrix = []
|
||||||
|
for spkr_idx, speaker in enumerate(dvecs):
|
||||||
|
cs_row = []
|
||||||
|
for utt_idx, utterance in enumerate(speaker):
|
||||||
|
new_centroids = self.calc_new_centroids(
|
||||||
|
dvecs, centroids, spkr_idx, utt_idx
|
||||||
|
)
|
||||||
|
# vector based cosine similarity for speed
|
||||||
|
cs_row.append(
|
||||||
|
torch.clamp(
|
||||||
|
torch.mm(
|
||||||
|
utterance.unsqueeze(1).transpose(0, 1),
|
||||||
|
new_centroids.transpose(0, 1),
|
||||||
|
)
|
||||||
|
/ (torch.norm(utterance) * torch.norm(new_centroids, dim=1)),
|
||||||
|
1e-6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cs_row = torch.cat(cs_row, dim=0)
|
||||||
|
cos_sim_matrix.append(cs_row)
|
||||||
|
return torch.stack(cos_sim_matrix)
|
||||||
|
|
||||||
|
# pylint: disable=R0201
|
||||||
|
def embed_loss_softmax(self, dvecs, cos_sim_matrix):
|
||||||
|
"""
|
||||||
|
Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
|
||||||
|
"""
|
||||||
|
N, M, _ = dvecs.shape
|
||||||
|
L = []
|
||||||
|
for j in range(N):
|
||||||
|
L_row = []
|
||||||
|
for i in range(M):
|
||||||
|
L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j])
|
||||||
|
L_row = torch.stack(L_row)
|
||||||
|
L.append(L_row)
|
||||||
|
return torch.stack(L)
|
||||||
|
|
||||||
|
# pylint: disable=R0201
|
||||||
|
def embed_loss_contrast(self, dvecs, cos_sim_matrix):
|
||||||
|
"""
|
||||||
|
Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
|
||||||
|
"""
|
||||||
|
N, M, _ = dvecs.shape
|
||||||
|
L = []
|
||||||
|
for j in range(N):
|
||||||
|
L_row = []
|
||||||
|
for i in range(M):
|
||||||
|
centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
|
||||||
|
excl_centroids_sigmoids = torch.cat(
|
||||||
|
(centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
|
||||||
|
)
|
||||||
|
L_row.append(
|
||||||
|
1.0
|
||||||
|
- torch.sigmoid(cos_sim_matrix[j, i, j])
|
||||||
|
+ torch.max(excl_centroids_sigmoids)
|
||||||
|
)
|
||||||
|
L_row = torch.stack(L_row)
|
||||||
|
L.append(L_row)
|
||||||
|
return torch.stack(L)
|
||||||
|
|
||||||
|
def forward(self, dvecs):
|
||||||
|
"""
|
||||||
|
Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
|
||||||
|
"""
|
||||||
|
centroids = torch.mean(dvecs, 1)
|
||||||
|
cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
|
||||||
|
torch.clamp(self.w, 1e-6)
|
||||||
|
cos_sim_matrix = self.w * cos_sim_matrix + self.b
|
||||||
|
L = self.embed_loss(dvecs, cos_sim_matrix)
|
||||||
|
return L.mean()
|
||||||
|
|
||||||
|
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
|
||||||
|
class AngleProtoLoss(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
|
||||||
|
Accepts an input of size (N, M, D)
|
||||||
|
where N is the number of speakers in the batch,
|
||||||
|
M is the number of utterances per speaker,
|
||||||
|
and D is the dimensionality of the embedding vector
|
||||||
|
Args:
|
||||||
|
- init_w (float): defines the initial value of w
|
||||||
|
- init_b (float): definies the initial value of b
|
||||||
|
"""
|
||||||
|
def __init__(self, init_w=10.0, init_b=-5.0):
|
||||||
|
super(AngleProtoLoss, self).__init__()
|
||||||
|
# pylint: disable=E1102
|
||||||
|
self.w = nn.Parameter(torch.tensor(init_w))
|
||||||
|
# pylint: disable=E1102
|
||||||
|
self.b = nn.Parameter(torch.tensor(init_b))
|
||||||
|
self.criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
print('Initialised Angular Prototypical loss')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
|
||||||
|
"""
|
||||||
|
out_anchor = torch.mean(x[:,1:,:],1)
|
||||||
|
out_positive = x[:,0,:]
|
||||||
|
num_speakers = out_anchor.size()[0]
|
||||||
|
|
||||||
|
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,num_speakers),out_anchor.unsqueeze(-1).expand(-1,-1,num_speakers).transpose(0,2))
|
||||||
|
torch.clamp(self.w, 1e-6)
|
||||||
|
cos_sim_matrix = cos_sim_matrix * self.w + self.b
|
||||||
|
label = torch.from_numpy(np.asarray(range(0,num_speakers))).to(cos_sim_matrix.device)
|
||||||
|
L = self.criterion(cos_sim_matrix, label)
|
||||||
|
return L
|
Loading…
Reference in New Issue