mirror of https://github.com/coqui-ai/TTS.git

commit 77d85c6cc5 (parent: 78bad25f2b)

    add softmaxproto loss and bug fix in data loader
Speaker encoder training script:

@@ -11,7 +11,7 @@ import torch
 from torch.utils.data import DataLoader

 from TTS.speaker_encoder.dataset import MyDataset
-from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
+from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxLoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.model import SpeakerEncoder
 from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model
 from TTS.speaker_encoder.utils.visual import plot_embeddings
@@ -45,15 +45,16 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     dataset = MyDataset(
         ap,
         meta_data_eval if is_val else meta_data_train,
-        voice_len=1.6,
+        voice_len=getattr(c, "voice_len", 1.6),
         num_utter_per_speaker=c.num_utters_per_speaker,
         num_speakers_in_batch=c.num_speakers_in_batch,
-        skip_speakers=False,
+        skip_speakers=getattr(c, "skip_speakers", False),
         storage_size=c.storage["storage_size"],
         sample_from_storage_p=c.storage["sample_from_storage_p"],
         additive_noise=c.storage["additive_noise"],
         verbose=verbose,
     )
+
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
     loader = DataLoader(
         dataset,
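Note that the two getattr calls above keep older configs working: if a config file predates this commit and lacks "voice_len" or "skip_speakers", the previously hard-coded values are used as defaults. A minimal sketch of that behavior (the Config class here is a stand-in for the loaded config object, not the project's real loader):

    # toy stand-in for the object the config loader returns
    class Config:
        pass

    c = Config()                          # old config: new keys absent
    print(getattr(c, "voice_len", 1.6))   # -> 1.6, the pre-commit default

    c.voice_len = 2                       # new config: key present
    print(getattr(c, "voice_len", 1.6))   # -> 2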
@@ -62,11 +63,25 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         num_workers=c.num_loader_workers,
         collate_fn=dataset.collate_fn,
     )
-    return loader
+    return loader, dataset.get_num_speakers()


-def train(model, criterion, optimizer, scheduler, ap, global_step):
-    data_loader = setup_loader(ap, is_val=False, verbose=True)
+def train(model, optimizer, scheduler, ap, global_step):
+    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
+
+    if c.loss == "ge2e":
+        criterion = GE2ELoss(loss_method="softmax")
+    elif c.loss == "angleproto":
+        criterion = AngleProtoLoss()
+    elif c.loss == "softmaxproto":
+        criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
+    else:
+        raise Exception("The %s not is a loss supported" % c.loss)
+
+    if use_cuda:
+        model = model.cuda()
+        criterion.cuda()
+
     model.train()
     epoch_time = 0
     best_loss = float("inf")
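Constructing the criterion now happens inside train() rather than main(), because SoftmaxAngleProtoLoss needs the number of speaker classes, which is only known once the dataset is built and reported through dataset.get_num_speakers(). A standalone sketch of the selection logic above (build_criterion is a hypothetical helper, not part of the commit; config stands in for the global c):

    from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss

    def build_criterion(config, num_speakers):
        # mirrors the if/elif chain in train()
        if config.loss == "ge2e":
            return GE2ELoss(loss_method="softmax")
        if config.loss == "angleproto":
            return AngleProtoLoss()
        if config.loss == "softmaxproto":
            # the softmax head needs the embedding size and the class count
            return SoftmaxAngleProtoLoss(config.model["proj_dim"], num_speakers)
        raise ValueError(f"{config.loss} is not a supported loss")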
@@ -77,7 +92,8 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
         start_time = time.time()

         # setup input data
-        inputs = data[0]
+        inputs, labels = data

         loader_time = time.time() - end_time
         global_step += 1
@@ -89,13 +105,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
         # dispatch data to GPU
         if use_cuda:
             inputs = inputs.cuda(non_blocking=True)
-            # labels = labels.cuda(non_blocking=True)
+            labels = labels.cuda(non_blocking=True)

         # forward pass model
         outputs = model(inputs)

         # loss computation
-        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
+        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
         loss.backward()
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
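The criterion expects d-vectors grouped per speaker, so the flat model output is viewed back into (num_speakers_in_batch, utterances per speaker, embedding dim) before the loss call, and the labels from the loader ride along with the same grouping. A shape walk-through with illustrative sizes:

    import torch

    num_speakers_in_batch, num_utters, proj_dim = 64, 2, 512
    outputs = torch.randn(num_speakers_in_batch * num_utters, proj_dim)   # flat model output
    labels = torch.randint(0, 1000, (num_speakers_in_batch, num_utters))  # class ids from the loader

    dvecs = outputs.view(num_speakers_in_batch, outputs.shape[0] // num_speakers_in_batch, -1)
    assert dvecs.shape == (64, 2, 512)   # one row of utterances per speaker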
@@ -158,13 +174,6 @@ def main(args): # pylint: disable=redefined-outer-name
     )
     optimizer = RAdam(model.parameters(), lr=c.lr)

-    if c.loss == "ge2e":
-        criterion = GE2ELoss(loss_method="softmax")
-    elif c.loss == "angleproto":
-        criterion = AngleProtoLoss()
-    else:
-        raise Exception("The %s not is a loss supported" % c.loss)
-
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
@@ -187,10 +196,6 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         args.restore_step = 0

-    if use_cuda:
-        model = model.cuda()
-        criterion.cuda()
-
     if c.lr_decay:
         scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
     else:
@@ -203,7 +208,7 @@ def main(args): # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)

     global_step = args.restore_step
-    _, global_step = train(model, criterion, optimizer, scheduler, ap, global_step)
+    _, global_step = train(model, optimizer, scheduler, ap, global_step)


 if __name__ == "__main__":
Speaker encoder config (JSON):

@@ -37,6 +37,9 @@
     "steps_plot_stats": 10, // number of steps to plot embeddings.
     "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "num_utters_per_speaker": 10, //
+    "skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker"
+
+    "voice_len": 1.6, // number of seconds for each training instance
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
New file, a speaker encoder config for training on VCTK with the softmaxproto loss:

@@ -0,0 +1,78 @@
+
+{
+    "run_name": "speaker_encoder",
+    "run_description": "train speaker encoder with VCTK",
+    "audio":{
+        // Audio processing parameters
+        "num_mels": 80, // size of the mel spec frame.
+        "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
+        "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+        "win_length": 1024, // stft window length in ms.
+        "hop_length": 256, // stft window hop-lengh in ms.
+        "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
+        "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
+        "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
+        "min_level_db": -100, // normalization range
+        "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
+        "power": 1.5, // value to sharpen wav signals after GL algorithm.
+        "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
+        "stft_pad_mode": "reflect",
+        // Normalization parameters
+        "signal_norm": true, // normalize the spec values in range [0, 1]
+        "symmetric_norm": true, // move normalization to range [-1, 1]
+        "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true, // clip normalized values into the range.
+        "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+        "spec_gain": 20.0,
+        "do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+        "trim_db": 60, // threshold for timming silence. Set this according to your dataset.
+        "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
+    },
+    "reinit_layers": [],
+
+    "loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
+    "grad_clip": 3.0, // upper limit for gradients for clipping.
+    "epochs": 1000, // total number of epochs to train.
+    "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
+    "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
+    "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
+    "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+    "steps_plot_stats": 10, // number of steps to plot embeddings.
+
+    // Speakers config
+    "num_speakers_in_batch": 108, // Batch size for training.
+    "num_utters_per_speaker": 2, //
+    "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
+
+    "voice_len": 2, // number of seconds for each training instance
+
+    "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
+    "wd": 0.000001, // Weight decay weight.
+    "checkpoint": true, // If true, it saves checkpoints per "save_step"
+    "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
+    "print_step": 20, // Number of steps to log traning on console.
+    "output_path": "../../../checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
+
+    "model": {
+        "input_dim": 80,
+        "proj_dim": 512,
+        "lstm_dim": 768,
+        "num_lstm_layers": 3,
+        "use_lstm_with_projection": true
+    },
+    "storage": {
+        "sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
+        "storage_size": 15, // the size of the in-memory storage with respect to a single batch
+        "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
+    },
+    "datasets":
+        [
+        {
+            "name": "vctk",
+            "path": "/workspace/store/ecasanova/datasets/VCTK-Corpus-removed-silence/",
+            "meta_file_train": null,
+            "meta_file_val": null
+        }
+        ]
+}
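With "loss": "softmaxproto", the combined loss asserts exactly two utterances per speaker (see the losses diff below), which this config satisfies with "num_utters_per_speaker": 2, giving 108 x 2 = 216 clips per batch. A hedged sanity-check sketch (the filename is hypothetical and the commented-JSON parsing is deliberately naive):

    import json, re

    with open("config_vctk_softmaxproto.json") as f:     # assumed filename
        cfg = json.loads(re.sub(r"//.*", "", f.read()))  # strip the // comments first

    if cfg["loss"] == "softmaxproto":
        # SoftmaxAngleProtoLoss.forward asserts exactly 2 utterances per speaker
        assert cfg["num_utters_per_speaker"] == 2
    print(cfg["num_speakers_in_batch"] * cfg["num_utters_per_speaker"], "clips per batch")  # 216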
TTS/speaker_encoder/dataset.py:

@@ -30,7 +30,6 @@ class MyDataset(Dataset):
         super().__init__()
         self.items = meta_data
         self.sample_rate = ap.sample_rate
-        self.voice_len = voice_len
         self.seq_len = int(voice_len * self.sample_rate)
         self.num_speakers_in_batch = num_speakers_in_batch
         self.num_utter_per_speaker = num_utter_per_speaker
@@ -41,10 +40,15 @@ class MyDataset(Dataset):
         self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
         self.sample_from_storage_p = float(sample_from_storage_p)
         self.additive_noise = float(additive_noise)
+
+        speakers_aux = list(self.speakers)
+        speakers_aux.sort()
+        self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)}
+
         if self.verbose:
             print("\n > DataLoader initialization")
             print(f" | > Speakers per Batch: {num_speakers_in_batch}")
-            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
+            print(f" | > Storage Size: {self.storage.maxsize} instances, each with {num_utter_per_speaker} utters")
             print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
             print(f" | > Noise added : {self.additive_noise}")
             print(f" | > Number of instances : {len(self.items)}")
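The sorted speaker-to-class-id mapping gives every speaker a deterministic integer label for the softmax head, independent of set iteration order. What it produces, with invented speaker names:

    speakers = {"p225", "p226", "p227"}          # set of speaker ids, as in the dataset
    speakers_aux = list(speakers)
    speakers_aux.sort()                          # sorting makes the ids stable across runs
    speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
    print(speakerid_to_classid)                  # {'p225': 0, 'p226': 1, 'p227': 2}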
@@ -110,8 +114,16 @@ class MyDataset(Dataset):
     def __len__(self):
         return int(1e10)

-    def __sample_speaker(self):
+    def get_num_speakers(self):
+        return len(self.speakers)
+
+    def __sample_speaker(self, ignore_speakers=None):
         speaker = random.sample(self.speakers, 1)[0]
+        # if list of speakers_id is provide make sure that it's will be ignored
+        if ignore_speakers:
+            while self.speakerid_to_classid[speaker] in ignore_speakers:
+                speaker = random.sample(self.speakers, 1)[0]
+
         if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
             utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
         else:
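__sample_speaker now accepts an optional set of class ids to avoid, implemented as a rejection loop: keep resampling until the drawn speaker is not in the ignore set. A toy equivalent, with invented names:

    import random

    speakers = ["p225", "p226", "p227"]
    classid = {"p225": 0, "p226": 1, "p227": 2}
    ignore = {0, 1}                              # class ids already used in the batch

    speaker = random.sample(speakers, 1)[0]
    while classid[speaker] in ignore:            # resample until an unused speaker comes up
        speaker = random.sample(speakers, 1)[0]
    print(speaker)                               # 'p227', the only speaker not ignored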
@@ -127,7 +139,8 @@ class MyDataset(Dataset):
         for _ in range(self.num_utter_per_speaker):
             # TODO:dummy but works
             while True:
-                if len(self.speaker_to_utters[speaker]) > 0:
+                # remove speakers that have num_utter less than 2
+                if len(self.speaker_to_utters[speaker]) > 1:
                     utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                 else:
                     self.speakers.remove(speaker)
@@ -139,21 +152,47 @@ class MyDataset(Dataset):
             self.speaker_to_utters[speaker].remove(utter)

             wavs.append(wav)
-            labels.append(speaker)
+            labels.append(self.speakerid_to_classid[speaker])
         return wavs, labels

     def __getitem__(self, idx):
         speaker, _ = self.__sample_speaker()
-        return speaker
+        speaker_id = self.speakerid_to_classid[speaker]
+        return speaker, speaker_id

     def collate_fn(self, batch):
+        # get the batch speaker_ids
+        batch = np.array(batch)
+        speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
+
         labels = []
         feats = []
-        for speaker in batch:
+        speakers = set()
+        for speaker, speaker_id in batch:
+
             if random.random() < self.sample_from_storage_p and self.storage.full():
                 # sample from storage (if full), ignoring the speaker
                 wavs_, labels_ = random.choice(self.storage.queue)
+
+                # force choose the current speaker or other not in batch
+                '''while labels_[0] in speakers_id_in_batch:
+                    if labels_[0] == speaker_id:
+                        break
+                    wavs_, labels_ = random.choice(self.storage.queue)'''
+
+                speakers.add(labels_[0])
+                speakers_id_in_batch.add(labels_[0])
+
             else:
+                # ensure that an speaker appears only once in the batch
+                if speaker_id in speakers:
+                    speaker, _ = self.__sample_speaker(speakers_id_in_batch)
+                    speaker_id = self.speakerid_to_classid[speaker]
+                    # append the new speaker from batch
+                    speakers_id_in_batch.add(speaker_id)
+
+                speakers.add(speaker_id)
+
                 # don't sample from storage, but from HDD
                 wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                 # if storage is full, remove an item
@@ -167,14 +206,15 @@ class MyDataset(Dataset):
             noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
             wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]

-            # get a random subset of each of the wavs and convert to MFCC.
+            # get a random subset of each of the wavs and extract mel spectrograms.
             offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
             mels_ = [
                 self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_))
             ]
             feats_ = [torch.FloatTensor(mel) for mel in mels_]

-            labels.append(labels_)
+            labels.append(torch.LongTensor(labels_))
             feats.extend(feats_)
         feats = torch.stack(feats)
+        labels = torch.stack(labels)
         return feats.transpose(1, 2), labels
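Because each labels_ entry holds num_utter_per_speaker copies of one class id, stacking yields a (num_speakers_in_batch, num_utter_per_speaker) LongTensor that lines up with the d-vector view in the training loop. With illustrative sizes:

    import torch

    labels = [torch.LongTensor([5, 5]), torch.LongTensor([9, 9])]  # one entry per speaker
    labels = torch.stack(labels)
    print(labels.shape)   # torch.Size([2, 2]) -> (speakers in batch, utters per speaker)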
TTS/speaker_encoder/losses.py:

@@ -103,10 +103,13 @@ class GE2ELoss(nn.Module):
             L.append(L_row)
         return torch.stack(L)

-    def forward(self, dvecs):
+    def forward(self, dvecs, label=None):
         """
         Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
         """
+
+        assert x.size()[1] >= 2
+
         centroids = torch.mean(dvecs, 1)
         cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
         torch.clamp(self.w, 1e-6)
@@ -138,10 +141,13 @@ class AngleProtoLoss(nn.Module):
         print(" > Initialised Angular Prototypical loss")

-    def forward(self, x):
+    def forward(self, x, label=None):
         """
         Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
         """
+
+        assert x.size()[1] >= 2
+
         out_anchor = torch.mean(x[:, 1:, :], 1)
         out_positive = x[:, 0, :]
         num_speakers = out_anchor.size()[0]
@@ -155,3 +161,57 @@ class AngleProtoLoss(nn.Module):
         label = torch.arange(num_speakers).to(cos_sim_matrix.device)
         L = self.criterion(cos_sim_matrix, label)
         return L
+
+
+class SoftmaxLoss(nn.Module):
+    """
+    Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
+    Args:
+        - embedding_dim (float): speaker embedding dim
+        - n_speakers (float): number of speakers
+    """
+    def __init__(self, embedding_dim, n_speakers):
+        super().__init__()
+
+        self.criterion = torch.nn.CrossEntropyLoss()
+        self.fc = nn.Linear(embedding_dim, n_speakers)
+
+        print('Initialised Softmax Loss')
+
+    def forward(self, x, label=None):
+        x = self.fc(x)
+        L = self.criterion(x, label)
+
+        return L
+
+
+class SoftmaxAngleProtoLoss(nn.Module):
+    """
+    Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
+    Args:
+        - embedding_dim (float): speaker embedding dim
+        - n_speakers (float): number of speakers
+        - init_w (float): defines the initial value of w
+        - init_b (float): definies the initial value of b
+    """
+    def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
+        super().__init__()
+
+        self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
+        self.angleproto = AngleProtoLoss(init_w, init_b)
+
+        print('Initialised SoftmaxAnglePrototypical Loss')
+
+    def forward(self, x, label=None):
+        """
+        Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
+        """
+
+        assert x.size()[1] == 2
+
+        Lp = self.angleproto(x)
+
+        x = x.reshape(-1, x.size()[-1])
+        label = label.reshape(-1)
+        Ls = self.softmax(x, label)
+
+        return Ls+Lp
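A minimal smoke test of the new combined loss, assuming only what the diff above defines (the sizes are arbitrary; the batch must have exactly two utterances per speaker to pass the assert in forward):

    import torch
    from TTS.speaker_encoder.losses import SoftmaxAngleProtoLoss

    n_speakers, emb_dim = 10, 512
    criterion = SoftmaxAngleProtoLoss(emb_dim, n_speakers)

    x = torch.randn(4, 2, emb_dim)                # (speakers in batch, 2 utterances, embedding)
    label = torch.randint(0, n_speakers, (4, 2))  # class id per utterance
    loss = criterion(x, label)                    # Ls (softmax) + Lp (angular prototypical)
    print(loss.item())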
TTS/speaker_encoder/utils/generic_utils.py:

@@ -82,7 +82,7 @@ def check_config_speaker_encoder(c):
     check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

     # training parameters
-    check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str)
+    check_argument("loss", c, enum_list=["ge2e", "angleproto", "softmaxproto"], restricted=True, val_type=str)
     check_argument("grad_clip", c, restricted=True, val_type=float)
     check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
     check_argument("lr", c, restricted=True, val_type=float, min_val=0)