add softmaxproto loss and bug fix in data loader

This commit is contained in:
Edresson 2021-05-10 17:08:38 -03:00
parent 78bad25f2b
commit 77d85c6cc5
6 changed files with 219 additions and 33 deletions

View File

@ -11,7 +11,7 @@ import torch
from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import MyDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxLoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model
from TTS.speaker_encoder.utils.visual import plot_embeddings
@ -45,15 +45,16 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
dataset = MyDataset(
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=1.6,
voice_len=getattr(c, "voice_len", 1.6),
num_utter_per_speaker=c.num_utters_per_speaker,
num_speakers_in_batch=c.num_speakers_in_batch,
skip_speakers=False,
skip_speakers=getattr(c, "skip_speakers", False),
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
additive_noise=c.storage["additive_noise"],
verbose=verbose,
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
@ -62,11 +63,25 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn,
)
return loader
return loader, dataset.get_num_speakers()
def train(model, criterion, optimizer, scheduler, ap, global_step):
data_loader = setup_loader(ap, is_val=False, verbose=True)
def train(model, optimizer, scheduler, ap, global_step):
data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
else:
raise Exception("The %s not is a loss supported" % c.loss)
if use_cuda:
model = model.cuda()
criterion.cuda()
model.train()
epoch_time = 0
best_loss = float("inf")
@ -77,7 +92,8 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
start_time = time.time()
# setup input data
inputs = data[0]
inputs, labels = data
loader_time = time.time() - end_time
global_step += 1
@ -89,13 +105,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
# dispatch data to GPU
if use_cuda:
inputs = inputs.cuda(non_blocking=True)
# labels = labels.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# forward pass model
outputs = model(inputs)
# loss computation
loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
@ -158,13 +174,6 @@ def main(args): # pylint: disable=redefined-outer-name
)
optimizer = RAdam(model.parameters(), lr=c.lr)
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
else:
raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path:
checkpoint = torch.load(args.restore_path)
try:
@ -187,10 +196,6 @@ def main(args): # pylint: disable=redefined-outer-name
else:
args.restore_step = 0
if use_cuda:
model = model.cuda()
criterion.cuda()
if c.lr_decay:
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
@ -203,7 +208,7 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
global_step = args.restore_step
_, global_step = train(model, criterion, optimizer, scheduler, ap, global_step)
_, global_step = train(model, optimizer, scheduler, ap, global_step)
if __name__ == "__main__":
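The criterion is now constructed inside train() because SoftmaxAngleProtoLoss needs the number of speakers, which only the dataset knows. The view() on the loss call maps the flat encoder output onto the (num_speakers, num_utters, proj_dim) shape the losses expect; a minimal sketch with dummy tensors, not part of the commit (dimensions are illustrative assumptions taken from the VCTK config below):

import torch

num_speakers_in_batch, num_utters, proj_dim = 108, 2, 512
outputs = torch.randn(num_speakers_in_batch * num_utters, proj_dim)  # stand-in for model(inputs)
# group consecutive utterances by speaker, as in the training step above
dvecs = outputs.view(num_speakers_in_batch, outputs.shape[0] // num_speakers_in_batch, -1)
assert dvecs.shape == (num_speakers_in_batch, num_utters, proj_dim)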

View File

@ -37,6 +37,9 @@
"steps_plot_stats": 10, // number of steps to plot embeddings.
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"num_utters_per_speaker": 10, //
"skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker"
"voice_len": 1.6, // number of seconds for each training instance
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"

View File

@ -0,0 +1,78 @@
{
"run_name": "speaker_encoder",
"run_description": "train speaker encoder with VCTK",
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
"stft_pad_mode": "reflect",
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 20.0,
"do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
"reinit_layers": [],
"loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
"grad_clip": 3.0, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 10, // number of steps to plot embeddings.
// Speakers config
"num_speakers_in_batch": 108, // Batch size for training.
"num_utters_per_speaker": 2, //
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
"voice_len": 2, // number of seconds for each training instance
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 20, // Number of steps to log traning on console.
"output_path": "../../../checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
"model": {
"input_dim": 80,
"proj_dim": 512,
"lstm_dim": 768,
"num_lstm_layers": 3,
"use_lstm_with_projection": true
},
"storage": {
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 15, // the size of the in-memory storage with respect to a single batch
"additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
},
"datasets":
[
{
"name": "vctk",
"path": "/workspace/store/ecasanova/datasets/VCTK-Corpus-removed-silence/",
"meta_file_train": null,
"meta_file_val": null
}
]
}
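For reference, the dataset (next file) derives a fixed per-utterance sample count from voice_len and sample_rate; with the values in this config that works out to:

# seq_len as computed in dataset.py: int(voice_len * sample_rate)
seq_len = int(2 * 16000)  # -> 32000 samples (2 s) per training instance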

View File

@ -30,7 +30,6 @@ class MyDataset(Dataset):
super().__init__()
self.items = meta_data
self.sample_rate = ap.sample_rate
self.voice_len = voice_len
self.seq_len = int(voice_len * self.sample_rate)
self.num_speakers_in_batch = num_speakers_in_batch
self.num_utter_per_speaker = num_utter_per_speaker
@ -41,10 +40,15 @@ class MyDataset(Dataset):
self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
self.sample_from_storage_p = float(sample_from_storage_p)
self.additive_noise = float(additive_noise)
speakers_aux = list(self.speakers)
speakers_aux.sort()
self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)}
if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Speakers per Batch: {num_speakers_in_batch}")
print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
print(f" | > Storage Size: {self.storage.maxsize} instances, each with {num_utter_per_speaker} utters")
print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
print(f" | > Noise added : {self.additive_noise}")
print(f" | > Number of instances : {len(self.items)}")
@ -110,8 +114,16 @@ class MyDataset(Dataset):
def __len__(self):
return int(1e10)
def __sample_speaker(self):
def get_num_speakers(self):
return len(self.speakers)
def __sample_speaker(self, ignore_speakers=None):
speaker = random.sample(self.speakers, 1)[0]
# if a list of speaker ids is provided, make sure they are ignored
if ignore_speakers:
while self.speakerid_to_classid[speaker] in ignore_speakers:
speaker = random.sample(self.speakers, 1)[0]
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
else:
@ -127,7 +139,8 @@ class MyDataset(Dataset):
for _ in range(self.num_utter_per_speaker):
# TODO: dummy but works
while True:
if len(self.speaker_to_utters[speaker]) > 0:
# remove speakers that have fewer than 2 utterances left
if len(self.speaker_to_utters[speaker]) > 1:
utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
else:
self.speakers.remove(speaker)
@ -139,21 +152,47 @@ class MyDataset(Dataset):
self.speaker_to_utters[speaker].remove(utter)
wavs.append(wav)
labels.append(speaker)
labels.append(self.speakerid_to_classid[speaker])
return wavs, labels
def __getitem__(self, idx):
speaker, _ = self.__sample_speaker()
return speaker
speaker_id = self.speakerid_to_classid[speaker]
return speaker, speaker_id
def collate_fn(self, batch):
# get the batch speaker_ids
batch = np.array(batch)
speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
labels = []
feats = []
for speaker in batch:
speakers = set()
for speaker, speaker_id in batch:
if random.random() < self.sample_from_storage_p and self.storage.full():
# sample from storage (if full), ignoring the speaker
wavs_, labels_ = random.choice(self.storage.queue)
# force-choose the current speaker, or another one not already in the batch
'''while labels_[0] in speakers_id_in_batch:
if labels_[0] == speaker_id:
break
wavs_, labels_ = random.choice(self.storage.queue)'''
speakers.add(labels_[0])
speakers_id_in_batch.add(labels_[0])
else:
# ensure that a speaker appears only once in the batch
if speaker_id in speakers:
speaker, _ = self.__sample_speaker(speakers_id_in_batch)
speaker_id = self.speakerid_to_classid[speaker]
# record the new speaker in the batch
speakers_id_in_batch.add(speaker_id)
speakers.add(speaker_id)
# don't sample from storage, but from HDD
wavs_, labels_ = self.__sample_speaker_utterances(speaker)
# if storage is full, remove an item
@ -167,14 +206,15 @@ class MyDataset(Dataset):
noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]
# get a random subset of each of the wavs and convert to MFCC.
# get a random subset of each of the wavs and extract mel spectrograms.
offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
mels_ = [
self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_))
]
feats_ = [torch.FloatTensor(mel) for mel in mels_]
labels.append(labels_)
labels.append(torch.LongTensor(labels_))
feats.extend(feats_)
feats = torch.stack(feats)
labels = torch.stack(labels)
return feats.transpose(1, 2), labels
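A shape sketch of what collate_fn now returns, not part of the commit and assuming the VCTK config above (108 speakers per batch, 2 utterances per speaker, 80 mels; the frame count is illustrative, it follows from seq_len and hop_length):

import torch

num_speakers_in_batch, num_utters, num_mels, frames = 108, 2, 80, 126  # frames is an assumption
# feats: one mel per utterance, stacked and transposed to (N, frames, num_mels)
feats = torch.randn(num_speakers_in_batch * num_utters, frames, num_mels)
# labels: one class id per utterance, grouped per speaker
labels = torch.arange(num_speakers_in_batch).unsqueeze(1).repeat(1, num_utters)
assert feats.shape[0] == labels.numel()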

View File

@ -103,10 +103,13 @@ class GE2ELoss(nn.Module):
L.append(L_row)
return torch.stack(L)
def forward(self, dvecs):
def forward(self, dvecs, label=None):
"""
Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
"""
assert dvecs.size()[1] >= 2
centroids = torch.mean(dvecs, 1)
cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
torch.clamp(self.w, 1e-6)
@ -138,10 +141,13 @@ class AngleProtoLoss(nn.Module):
print(" > Initialised Angular Prototypical loss")
def forward(self, x):
def forward(self, x, label=None):
"""
Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
"""
assert x.size()[1] >= 2
out_anchor = torch.mean(x[:, 1:, :], 1)
out_positive = x[:, 0, :]
num_speakers = out_anchor.size()[0]
@ -155,3 +161,57 @@ class AngleProtoLoss(nn.Module):
label = torch.arange(num_speakers).to(cos_sim_matrix.device)
L = self.criterion(cos_sim_matrix, label)
return L
class SoftmaxLoss(nn.Module):
"""
Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
Args:
- embedding_dim (int): speaker embedding dim
- n_speakers (int): number of speakers
"""
def __init__(self, embedding_dim, n_speakers):
super().__init__()
self.criterion = torch.nn.CrossEntropyLoss()
self.fc = nn.Linear(embedding_dim, n_speakers)
print("Initialised Softmax Loss")
def forward(self, x, label=None):
x = self.fc(x)
L = self.criterion(x, label)
return L
class SoftmaxAngleProtoLoss(nn.Module):
"""
Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
Args:
- embedding_dim (int): speaker embedding dim
- n_speakers (int): number of speakers
- init_w (float): defines the initial value of w
- init_b (float): defines the initial value of b
"""
def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
super().__init__()
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
self.angleproto = AngleProtoLoss(init_w, init_b)
print("Initialised SoftmaxAnglePrototypical Loss")
def forward(self, x, label=None):
"""
Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
"""
assert x.size()[1] == 2
Lp = self.angleproto(x)
x = x.reshape(-1, x.size()[-1])
label = label.reshape(-1)
Ls = self.softmax(x, label)
return Ls + Lp
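A quick sanity check of the combined loss on dummy data, a sketch rather than part of the commit: SoftmaxAngleProtoLoss requires exactly 2 utterances per speaker and returns the sum of its two components.

import torch
from TTS.speaker_encoder.losses import SoftmaxAngleProtoLoss

n_speakers, emb_dim = 10, 64
criterion = SoftmaxAngleProtoLoss(emb_dim, n_speakers)

x = torch.randn(n_speakers, 2, emb_dim)  # (num_speakers, num_utts, dvec_feats)
label = torch.arange(n_speakers).unsqueeze(1).repeat(1, 2)  # utterances of a speaker share a class id
total = criterion(x, label)
expected = criterion.angleproto(x) + criterion.softmax(x.reshape(-1, emb_dim), label.reshape(-1))
assert torch.isclose(total, expected)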

View File

@ -82,7 +82,7 @@ def check_config_speaker_encoder(c):
check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)
# training parameters
check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str)
check_argument("loss", c, enum_list=["ge2e", "angleproto", "softmaxproto"], restricted=True, val_type=str)
check_argument("grad_clip", c, restricted=True, val_type=float)
check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
check_argument("lr", c, restricted=True, val_type=float, min_val=0)