mirror of https://github.com/coqui-ai/TTS.git
add softmaxproto loss and bug fix in data loader
This commit is contained in:
parent 78bad25f2b
commit 77d85c6cc5
@@ -11,7 +11,7 @@ import torch
 from torch.utils.data import DataLoader

 from TTS.speaker_encoder.dataset import MyDataset
-from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
+from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxLoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.model import SpeakerEncoder
 from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model
 from TTS.speaker_encoder.utils.visual import plot_embeddings
@@ -45,15 +45,16 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     dataset = MyDataset(
         ap,
         meta_data_eval if is_val else meta_data_train,
-        voice_len=1.6,
+        voice_len=getattr(c, "voice_len", 1.6),
         num_utter_per_speaker=c.num_utters_per_speaker,
         num_speakers_in_batch=c.num_speakers_in_batch,
-        skip_speakers=False,
+        skip_speakers=getattr(c, "skip_speakers", False),
         storage_size=c.storage["storage_size"],
         sample_from_storage_p=c.storage["sample_from_storage_p"],
         additive_noise=c.storage["additive_noise"],
         verbose=verbose,
     )
+
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
     loader = DataLoader(
         dataset,
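Note on the two getattr calls above: they make the new config keys optional, so older speaker-encoder configs that lack "voice_len" or "skip_speakers" keep working with the previous defaults. A minimal sketch of the fallback behaviour (the Config class is a stand-in, not from the commit):

    class Config:  # stand-in for the parsed config namespace
        pass

    c = Config()
    getattr(c, "voice_len", 1.6)   # -> 1.6, key absent, old default is used
    c.voice_len = 2
    getattr(c, "voice_len", 1.6)   # -> 2, key present, config value wins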
@@ -62,11 +63,25 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         num_workers=c.num_loader_workers,
         collate_fn=dataset.collate_fn,
     )
-    return loader
+    return loader, dataset.get_num_speakers()


-def train(model, criterion, optimizer, scheduler, ap, global_step):
-    data_loader = setup_loader(ap, is_val=False, verbose=True)
+def train(model, optimizer, scheduler, ap, global_step):
+    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
+
+    if c.loss == "ge2e":
+        criterion = GE2ELoss(loss_method="softmax")
+    elif c.loss == "angleproto":
+        criterion = AngleProtoLoss()
+    elif c.loss == "softmaxproto":
+        criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
+    else:
+        raise Exception("The loss %s is not supported" % c.loss)
+
+    if use_cuda:
+        model = model.cuda()
+        criterion.cuda()
+
     model.train()
     epoch_time = 0
     best_loss = float("inf")
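Note: unlike GE2E and AngleProto, which only compare embeddings within a batch, the softmaxproto criterion owns trainable parameters sized by the speaker count; that is why setup_loader now also returns dataset.get_num_speakers(). A sketch of the classification head this implies (toy sizes, not from the commit):

    import torch.nn as nn

    proj_dim, num_speakers = 512, 109        # hypothetical values
    fc = nn.Linear(proj_dim, num_speakers)   # the head SoftmaxLoss builds internally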
@@ -77,7 +92,8 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
         start_time = time.time()

         # setup input data
-        inputs = data[0]
+        inputs, labels = data
+
         loader_time = time.time() - end_time
         global_step += 1
@@ -89,13 +105,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
         # dispatch data to GPU
         if use_cuda:
             inputs = inputs.cuda(non_blocking=True)
-            # labels = labels.cuda(non_blocking=True)
+            labels = labels.cuda(non_blocking=True)

         # forward pass model
         outputs = model(inputs)

         # loss computation
-        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
+        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
         loss.backward()
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
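Note on the view call: the flat model output is regrouped so the loss sees one row of utterances per speaker, and the labels from collate_fn line up with that grouping. A hedged shape walk-through with toy sizes:

    import torch

    num_speakers_in_batch, num_utter, proj_dim = 64, 10, 512      # toy sizes
    outputs = torch.randn(num_speakers_in_batch * num_utter, proj_dim)
    dvecs = outputs.view(num_speakers_in_batch, num_utter, -1)    # (64, 10, 512)
    labels = torch.randint(0, 200, (num_speakers_in_batch, num_utter))
    # criterion(dvecs, labels): GE2E/AngleProto ignore labels, softmaxproto uses them.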
@@ -158,13 +174,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     )
     optimizer = RAdam(model.parameters(), lr=c.lr)
-
-    if c.loss == "ge2e":
-        criterion = GE2ELoss(loss_method="softmax")
-    elif c.loss == "angleproto":
-        criterion = AngleProtoLoss()
-    else:
-        raise Exception("The %s not is a loss supported" % c.loss)

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
@@ -187,10 +196,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         args.restore_step = 0

-    if use_cuda:
-        model = model.cuda()
-        criterion.cuda()
-
     if c.lr_decay:
         scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
     else:
@@ -203,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)

     global_step = args.restore_step
-    _, global_step = train(model, criterion, optimizer, scheduler, ap, global_step)
+    _, global_step = train(model, optimizer, scheduler, ap, global_step)


 if __name__ == "__main__":
@@ -37,6 +37,9 @@
     "steps_plot_stats": 10, // number of steps to plot embeddings.
     "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "num_utters_per_speaker": 10, //
+    "skip_speakers": false, // skip speakers with fewer samples than "num_utters_per_speaker"
+
+    "voice_len": 1.6, // number of seconds for each training instance
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
@@ -0,0 +1,78 @@
+
+{
+    "run_name": "speaker_encoder",
+    "run_description": "train speaker encoder with VCTK",
+    "audio":{
+        // Audio processing parameters
+        "num_mels": 80, // size of the mel spec frame.
+        "fft_size": 1024, // number of stft frequency levels. Size of the linear spectrogram frame.
+        "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+        "win_length": 1024, // stft window length in samples.
+        "hop_length": 256, // stft window hop length in samples.
+        "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
+        "frame_shift_ms": null, // stft window hop length in ms. If null, 'hop_length' is used.
+        "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
+        "min_level_db": -100, // normalization range
+        "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
+        "power": 1.5, // value to sharpen wav signals after GL algorithm.
+        "griffin_lim_iters": 60, // #griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+        "stft_pad_mode": "reflect",
+        // Normalization parameters
+        "signal_norm": true, // normalize the spec values in range [0, 1]
+        "symmetric_norm": true, // move normalization to range [-1, 1]
+        "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true, // clip normalized values into the range.
+        "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+        "spec_gain": 20.0,
+        "do_trim_silence": false, // enable trimming of silence of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+        "trim_db": 60, // threshold for trimming silence. Set this according to your dataset.
+        "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
+    },
+    "reinit_layers": [],
+
+    "loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
+    "grad_clip": 3.0, // upper limit for gradients for clipping.
+    "epochs": 1000, // total number of epochs to train.
+    "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
+    "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
+    "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
+    "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+    "steps_plot_stats": 10, // number of steps to plot embeddings.
+
+    // Speakers config
+    "num_speakers_in_batch": 108, // Batch size for training.
+    "num_utters_per_speaker": 2, //
+    "skip_speakers": true, // skip speakers with fewer samples than "num_utters_per_speaker"
+
+    "voice_len": 2, // number of seconds for each training instance
+
+    "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
+    "wd": 0.000001, // Weight decay weight.
+    "checkpoint": true, // If true, it saves checkpoints per "save_step"
+    "save_step": 1000, // Number of training steps expected to save training stats and checkpoints.
+    "print_step": 20, // Number of steps to log training on console.
+    "output_path": "../../../checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
+
+    "model": {
+        "input_dim": 80,
+        "proj_dim": 512,
+        "lstm_dim": 768,
+        "num_lstm_layers": 3,
+        "use_lstm_with_projection": true
+    },
+    "storage": {
+        "sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
+        "storage_size": 15, // the size of the in-memory storage with respect to a single batch
+        "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
+    },
+    "datasets":
+        [
+        {
+            "name": "vctk",
+            "path": "/workspace/store/ecasanova/datasets/VCTK-Corpus-removed-silence/",
+            "meta_file_train": null,
+            "meta_file_val": null
+        }
+        ]
+}
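Usage sketch for the new config (hedged: the file name is illustrative, and load_config lived in TTS.utils.io around the time of this commit; the path may differ in later versions):

    from TTS.utils.io import load_config

    c = load_config("config_vctk.json")   # hypothetical file name
    assert c.loss == "softmaxproto"
    assert c.model["proj_dim"] == 512     # must match the loss's embedding_dim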
@@ -30,7 +30,6 @@ class MyDataset(Dataset):
         super().__init__()
         self.items = meta_data
         self.sample_rate = ap.sample_rate
-        self.voice_len = voice_len
         self.seq_len = int(voice_len * self.sample_rate)
         self.num_speakers_in_batch = num_speakers_in_batch
         self.num_utter_per_speaker = num_utter_per_speaker
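Quick arithmetic check for seq_len under the VCTK config above (sample_rate 16000, voice_len 2):

    sample_rate, voice_len = 16000, 2
    seq_len = int(voice_len * sample_rate)   # 32000 samples per training instance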
@@ -41,10 +40,15 @@ class MyDataset(Dataset):
         self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
         self.sample_from_storage_p = float(sample_from_storage_p)
         self.additive_noise = float(additive_noise)
+
+        speakers_aux = list(self.speakers)
+        speakers_aux.sort()
+        self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
+
         if self.verbose:
             print("\n > DataLoader initialization")
             print(f" | > Speakers per Batch: {num_speakers_in_batch}")
-            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
+            print(f" | > Storage Size: {self.storage.maxsize} instances, each with {num_utter_per_speaker} utters")
             print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
             print(f" | > Noise added : {self.additive_noise}")
             print(f" | > Number of instances : {len(self.items)}")
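Illustration of the new speaker-id to class-id mapping (toy speaker names): sorting first makes the ids deterministic across runs, and the contiguous integers are exactly what CrossEntropyLoss expects:

    speakers = {"p227", "p225", "p226"}
    speakers_aux = sorted(speakers)
    speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
    # -> {'p225': 0, 'p226': 1, 'p227': 2}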
@@ -110,8 +114,16 @@ class MyDataset(Dataset):
     def __len__(self):
         return int(1e10)

-    def __sample_speaker(self):
+    def get_num_speakers(self):
+        return len(self.speakers)
+
+    def __sample_speaker(self, ignore_speakers=None):
         speaker = random.sample(self.speakers, 1)[0]
+        # if a list of speaker ids is given, resample until the speaker is not in it
+        if ignore_speakers:
+            while self.speakerid_to_classid[speaker] in ignore_speakers:
+                speaker = random.sample(self.speakers, 1)[0]
+
         if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
             utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
         else:
@@ -127,7 +139,8 @@ class MyDataset(Dataset):
         for _ in range(self.num_utter_per_speaker):
             # TODO:dummy but works
             while True:
-                if len(self.speaker_to_utters[speaker]) > 0:
+                # remove speakers that have fewer than 2 utterances left
+                if len(self.speaker_to_utters[speaker]) > 1:
                     utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                 else:
                     self.speakers.remove(speaker)
@@ -139,21 +152,47 @@ class MyDataset(Dataset):
                 self.speaker_to_utters[speaker].remove(utter)

             wavs.append(wav)
-            labels.append(speaker)
+            labels.append(self.speakerid_to_classid[speaker])
         return wavs, labels

     def __getitem__(self, idx):
         speaker, _ = self.__sample_speaker()
-        return speaker
+        speaker_id = self.speakerid_to_classid[speaker]
+        return speaker, speaker_id

     def collate_fn(self, batch):
+        # get the batch speaker_ids
+        batch = np.array(batch)
+        speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
+
         labels = []
         feats = []
-        for speaker in batch:
+        speakers = set()
+        for speaker, speaker_id in batch:
+
             if random.random() < self.sample_from_storage_p and self.storage.full():
                 # sample from storage (if full), ignoring the speaker
                 wavs_, labels_ = random.choice(self.storage.queue)
+
+                # force choose the current speaker or another not in the batch
+                '''while labels_[0] in speakers_id_in_batch:
+                    if labels_[0] == speaker_id:
+                        break
+                    wavs_, labels_ = random.choice(self.storage.queue)'''
+
+                speakers.add(labels_[0])
+                speakers_id_in_batch.add(labels_[0])
+
             else:
+                # ensure that a speaker appears only once in the batch
+                if speaker_id in speakers:
+                    speaker, _ = self.__sample_speaker(speakers_id_in_batch)
+                    speaker_id = self.speakerid_to_classid[speaker]
+                    # add the new speaker to the batch set
+                    speakers_id_in_batch.add(speaker_id)
+
+                speakers.add(speaker_id)
+
                 # don't sample from storage, but from HDD
                 wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                 # if storage is full, remove an item
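Note on collate_fn's new input: since __getitem__ now returns a (speaker, class_id) pair, the batch arrives as a list of pairs; np.array turns it into a string array, so the ids must be cast back to integers. A hedged sketch with toy values:

    import numpy as np

    batch = [("p225", 0), ("p226", 1), ("p227", 2)]   # toy (speaker, class_id) pairs
    arr = np.array(batch)                             # mixed tuple -> string array
    speakers_id_in_batch = set(arr[:, 1].astype(np.int32))   # {0, 1, 2}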
@@ -167,14 +206,15 @@ class MyDataset(Dataset):
             noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
             wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]

-            # get a random subset of each of the wavs and convert to MFCC.
+            # get a random subset of each of the wavs and extract mel spectrograms.
             offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
             mels_ = [
                 self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_))
             ]
             feats_ = [torch.FloatTensor(mel) for mel in mels_]

-            labels.append(labels_)
+            labels.append(torch.LongTensor(labels_))
             feats.extend(feats_)
         feats = torch.stack(feats)
+        labels = torch.stack(labels)
         return feats.transpose(1, 2), labels
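Shape sketch for the collate output (hedged, values taken from the configs above): each wav slice of seq_len samples becomes a mel spectrogram of shape (num_mels, T), and the final transpose yields (batch, T, num_mels) to match the encoder's input_dim:

    import torch

    seq_len, hop_length, num_mels = 32000, 256, 80
    T = seq_len // hop_length + 1                    # approximate mel frame count
    feats = torch.stack([torch.randn(num_mels, T) for _ in range(8)])
    feats = feats.transpose(1, 2)                    # (8, T, 80)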
@@ -103,10 +103,13 @@ class GE2ELoss(nn.Module):
             L.append(L_row)
         return torch.stack(L)

-    def forward(self, dvecs):
+    def forward(self, dvecs, label=None):
         """
         Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
         """
+
+        assert dvecs.size()[1] >= 2
+
         centroids = torch.mean(dvecs, 1)
         cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
         torch.clamp(self.w, 1e-6)
@@ -138,10 +141,13 @@ class AngleProtoLoss(nn.Module):

         print(" > Initialised Angular Prototypical loss")

-    def forward(self, x):
+    def forward(self, x, label=None):
         """
         Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
         """
+
+        assert x.size()[1] >= 2
+
         out_anchor = torch.mean(x[:, 1:, :], 1)
         out_positive = x[:, 0, :]
         num_speakers = out_anchor.size()[0]
@@ -155,3 +161,57 @@ class AngleProtoLoss(nn.Module):
         label = torch.arange(num_speakers).to(cos_sim_matrix.device)
         L = self.criterion(cos_sim_matrix, label)
         return L
+
+class SoftmaxLoss(nn.Module):
+    """
+    Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
+    Args:
+        - embedding_dim (int): speaker embedding dim
+        - n_speakers (int): number of speakers
+    """
+    def __init__(self, embedding_dim, n_speakers):
+        super().__init__()
+
+        self.criterion = torch.nn.CrossEntropyLoss()
+        self.fc = nn.Linear(embedding_dim, n_speakers)
+
+        print('Initialised Softmax Loss')
+
+    def forward(self, x, label=None):
+
+        x = self.fc(x)
+        L = self.criterion(x, label)
+
+        return L
+
+class SoftmaxAngleProtoLoss(nn.Module):
+    """
+    Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
+    Args:
+        - embedding_dim (int): speaker embedding dim
+        - n_speakers (int): number of speakers
+        - init_w (float): defines the initial value of w
+        - init_b (float): defines the initial value of b
+    """
+    def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
+        super().__init__()
+
+        self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
+        self.angleproto = AngleProtoLoss(init_w, init_b)
+
+        print('Initialised SoftmaxAnglePrototypical Loss')
+
+    def forward(self, x, label=None):
+        """
+        Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
+        """
+
+        assert x.size()[1] == 2
+
+        Lp = self.angleproto(x)
+
+        x = x.reshape(-1, x.size()[-1])
+        label = label.reshape(-1)
+        Ls = self.softmax(x, label)
+
+        return Ls + Lp
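Hedged usage sketch for the new loss (toy sizes; the forward asserts exactly 2 utterances per speaker, matching "num_utters_per_speaker": 2 in the VCTK config):

    import torch

    n_speakers, n_utter, emb_dim = 4, 2, 512
    criterion = SoftmaxAngleProtoLoss(emb_dim, n_speakers)
    x = torch.randn(n_speakers, n_utter, emb_dim)                # d-vectors grouped per speaker
    label = torch.randint(0, n_speakers, (n_speakers, n_utter))  # class ids per utterance
    loss = criterion(x, label)                                   # Ls (softmax) + Lp (angleproto)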
@@ -82,7 +82,7 @@ def check_config_speaker_encoder(c):
     check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

     # training parameters
-    check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str)
+    check_argument("loss", c, enum_list=["ge2e", "angleproto", "softmaxproto"], restricted=True, val_type=str)
     check_argument("grad_clip", c, restricted=True, val_type=float)
     check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
     check_argument("lr", c, restricted=True, val_type=float, min_val=0)