solve merge problems

commit c90037c2e9
parent f89cb6aec2
Author: Edresson
Date:   2021-05-26 16:01:30 -03:00

8 changed files with 40 additions and 110 deletions

View File

@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 from TTS.speaker_encoder.dataset import MyDataset
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
-from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model, setup_model
+from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
 from TTS.speaker_encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets.preprocess import load_meta_data
@@ -38,15 +38,16 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     dataset = MyDataset(
         ap,
         meta_data_eval if is_val else meta_data_train,
-        voice_len=getattr(c, "voice_len", 1.6),
+        voice_len=c.voice_len,
         num_utter_per_speaker=c.num_utters_per_speaker,
         num_speakers_in_batch=c.num_speakers_in_batch,
-        skip_speakers=getattr(c, "skip_speakers", False),
+        skip_speakers=c.skip_speakers,
         storage_size=c.storage["storage_size"],
         sample_from_storage_p=c.storage["sample_from_storage_p"],
         verbose=verbose,
-        augmentation_config=getattr(c, "audio_augmentation", None)
+        augmentation_config=c.audio_augmentation
     )
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
     loader = DataLoader(
         dataset,
@@ -133,17 +134,15 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
            )
        avg_loss_all += avg_loss

-        if global_step % c.save_step == 0:
-            # save best model
+        if global_step >= c.max_train_step or global_step % c.save_step == 0:
+            # save best model only
             best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step)
             avg_loss_all = 0
-        end_time = time.time()
-        # checkpoint and check stop train cond.
-        if global_step >= c.max_train_step or global_step % c.save_step == 0:
-            save_checkpoint(model, optimizer, avg_loss, OUT_PATH, global_step)
         if global_step >= c.max_train_step:
             break
+        end_time = time.time()

     return avg_loss, global_step

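Note on the train() hunk above: the best-model save and the stop condition now share one check, the duplicated condition and the extra save_checkpoint() call are gone, and end_time is taken after the stop check. A small self-contained sketch (toy numbers, not the real config values) of what the combined condition does:

    # Toy illustration of the new save/stop logic; save_best_model is replaced by a list append.
    save_step = 1000
    max_train_step = 3500

    saves, global_step = [], 0
    while True:
        global_step += 1
        # save the best model on every save interval and also on the very last step
        if global_step >= max_train_step or global_step % save_step == 0:
            saves.append(global_step)
        # stop as soon as the step budget is spent
        if global_step >= max_train_step:
            break

    print(saves)  # [1000, 2000, 3000, 3500] -> a final save happens even off the save_step grid
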
View File

@@ -226,7 +226,7 @@ class BaseTrainingConfig(Coqpit):
     run_description: str = ""
     # training params
     epochs: int = 10000
-    batch_size: int = MISSING
+    batch_size: int = None
     eval_batch_size: int = None
     mixed_precision: bool = False
     # eval params

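For context on the BaseTrainingConfig hunk: batch_size drops coqpit's MISSING marker (a required field) in favour of a plain None default, since the speaker encoder sizes its batches through num_speakers_in_batch / num_utters_per_speaker instead. A minimal sketch of the pattern, with an illustrative class name and assuming coqpit is installed:

    from dataclasses import dataclass

    from coqpit import Coqpit


    @dataclass
    class TrainingConfigSketch(Coqpit):  # hypothetical stand-in for BaseTrainingConfig
        epochs: int = 10000
        batch_size: int = None        # was MISSING, i.e. treated as a required field
        eval_batch_size: int = None
        mixed_precision: bool = False


    c = TrainingConfigSketch()
    print(c.batch_size)  # None -> usable as-is by configs that never set a flat batch size
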
View File

@@ -1,5 +1,5 @@
 {
-    "model_name": "resnet",
+    "model": "speaker_encoder",
     "run_name": "speaker_encoder",
     "run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
     // AUDIO PARAMETERS
@@ -34,7 +34,7 @@
     "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
     "grad_clip": 3.0, // upper limit for gradients for clipping.
-    "epochs": 1000, // total number of epochs to train.
+    "max_train_step": 1000000, // total number of steps to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
@@ -45,15 +45,14 @@
     "num_speakers_in_batch": 200, // Batch size for training.
     "num_utters_per_speaker": 2, //
     "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
     "voice_len": 2, // number of seconds for each training instance
-    "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
+    "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "save_step": 1000, // Number of training steps expected to save the best checkpoints in training.
     "print_step": 50, // Number of steps to log traning on console.
-    "output_path": "../../../checkpoints/speaker_encoder/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto/", // DATASET-RELATED: output path for all training outputs.
+    "output_path": "../checkpoints/speaker_encoder/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto-continue/", // DATASET-RELATED: output path for all training outputs.
     "audio_augmentation": {
         "p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation
@@ -90,12 +89,13 @@
             "max_amplitude": 1e-5
         }
     },
-    "model": {
+    "model_params": {
+        "model_name": "resnet",
         "input_dim": 80,
         "proj_dim": 512
     },
     "storage": {
-        "sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
+        "sample_from_storage_p": 0.5, // the probability with which we'll sample from the DataSet in-memory storage
         "storage_size": 35 // the size of the in-memory storage with respect to a single batch
     },
     "datasets":

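Two smaller knobs change in this config as well: num_loader_workers goes from 8 to 4 and sample_from_storage_p from 0.66 to 0.5. As the inline comment says, the latter is the probability of drawing a batch item from the dataset's in-memory storage instead of loading it from disk; a rough, self-contained sketch of the effect (the real sampling logic lives in MyDataset):

    import random

    random.seed(0)
    sample_from_storage_p = 0.5   # new value above (was 0.66)
    from_storage = 0

    for _ in range(10_000):
        if random.random() < sample_from_storage_p:
            from_storage += 1     # reuse a cached, already-processed utterance
        # else: load and process a fresh utterance from disk

    print(from_storage / 10_000)  # roughly 0.5 of the samples come from storage
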
View File

@@ -1,6 +1,6 @@
 {
-    "model_name": "resnet",
+    "model": "speaker_encoder",
     "run_name": "speaker_encoder",
     "run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
     // AUDIO PARAMETERS
@@ -35,7 +35,7 @@
     "loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
     "grad_clip": 3.0, // upper limit for gradients for clipping.
-    "epochs": 1000, // total number of epochs to train.
+    "max_train_step": 1000000, // total number of steps to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
@@ -46,7 +46,6 @@
     "num_speakers_in_batch": 200, // Batch size for training.
     "num_utters_per_speaker": 2, //
     "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
     "voice_len": 2, // number of seconds for each training instance
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
@@ -91,7 +90,8 @@
             "max_amplitude": 1e-5
         }
     },
-    "model": {
+    "model_params": {
+        "model_name": "resnet",
         "input_dim": 80,
         "proj_dim": 512
     },

View File

@@ -240,9 +240,6 @@ class MyDataset(Dataset):
             labels.append(torch.LongTensor(labels_))
             feats.extend(feats_)
-        if self.num_speakers_in_batch != len(speakers):
-            raise ValueError('Error: Speakers appear more than once on the Batch. This cannot happen because the loss functions AngleProto and GE2E consider these samples to be from another speaker.')
         feats = torch.stack(feats)
         labels = torch.stack(labels)

View File

@@ -103,7 +103,7 @@ class GE2ELoss(nn.Module):
             L.append(L_row)
         return torch.stack(L)

-    def forward(self, x, label=None):
+    def forward(self, x, _label=None):
         """
         Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
         """
@@ -141,8 +141,7 @@ class AngleProtoLoss(nn.Module):
         print(" > Initialized Angular Prototypical loss")

-    # pylint: disable=W0613
-    def forward(self, x, label=None):
+    def forward(self, x, _label=None):
         """
         Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
         """

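The forward() signature change above brings GE2ELoss and AngleProtoLoss in line with SoftmaxAngleProtoLoss, which actually consumes the labels, so the training script can call whichever criterion is configured with the same two arguments. A hedged usage sketch with illustrative shapes:

    import torch

    from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss

    # (num_speakers, num_utts_per_speaker, dvec_feats), as in the docstrings above
    x = torch.randn(4, 2, 256)
    labels = torch.arange(4).unsqueeze(1).repeat(1, 2)  # ignored by these two losses

    for criterion in (GE2ELoss(), AngleProtoLoss()):
        loss = criterion(x, labels)   # second argument is accepted but unused (_label)
        print(type(criterion).__name__, float(loss))
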
View File

@@ -13,11 +13,11 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
     model: str = "speaker_encoder"
     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # model params
     model_params: dict = field(
         default_factory=lambda: {
-            "input_dim": 40,
+            "model_name": "lstm",
+            "input_dim": 80,
             "proj_dim": 256,
             "lstm_dim": 768,
             "num_lstm_layers": 3,
@@ -25,16 +25,20 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
         }
     )
+    audio_augmentation : dict = field(
+        default_factory=lambda: {
+        }
+    )
     storage: dict = field(
         default_factory=lambda: {
             "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
             "storage_size": 15, # the size of the in-memory storage with respect to a single batch
-            "additive_noise": 1e-5, # add very small gaussian noise to the data in order to increase robustness
         }
     )
     # training params
-    max_train_step: int = 1000  # end training when number of training steps reaches this value.
+    max_train_step: int = 1000000  # end training when number of training steps reaches this value.
     loss: str = "angleproto"
     grad_clip: float = 3.0
     lr: float = 0.0001
@@ -53,6 +57,8 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
     num_speakers_in_batch: int = MISSING
     num_utters_per_speaker: int = MISSING
     num_loader_workers: int = MISSING
+    skip_speakers: bool = False
+    voice_len: float = 1.6

     def check_values(self):
         super().check_values()

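With the SpeakerEncoderConfig changes above, the defaults the training script now relies on (model_params with model_name, voice_len, skip_speakers, audio_augmentation, max_train_step) are all present on a fresh config. A short hedged check, assuming the usual import path for this file:

    from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig  # assumed path

    c = SpeakerEncoderConfig()
    print(c.model_params["model_name"], c.model_params["input_dim"])  # lstm 80
    print(c.voice_len, c.skip_speakers)                               # 1.6 False
    print(c.max_train_step)                                           # 1000000
    print(c.audio_augmentation)                                       # {} -> no augmentation configured by default
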
View File

@@ -1,16 +1,17 @@
 import re
+import os
 import numpy as np
 import torch
 import glob
 import random
+import datetime
 from scipy import signal
 from multiprocessing import Manager

 from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
-from TTS.utils.generic_utils import check_argument

 class Storage(object):
     def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
@@ -157,10 +158,10 @@ def to_camel(text):
     return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

 def setup_model(c):
-    if c.model_name.lower() == 'lstm':
-        model = LSTMSpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"])
-    elif c.model_name.lower() == 'resnet':
-        model = ResNetSpeakerEncoder(input_dim=c.model["input_dim"], proj_dim=c.model["proj_dim"])
+    if c.model_params['model_name'].lower() == 'lstm':
+        model = LSTMSpeakerEncoder(c.model_params["input_dim"], c.model_params["proj_dim"], c.model_params["lstm_dim"], c.model_params["num_lstm_layers"])
+    elif c.model_params['model_name'].lower() == 'resnet':
+        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
     return model

 def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
@@ -198,75 +199,3 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
         print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
         torch.save(state, bestmodel_path)
     return best_loss
-
-
-def check_config_speaker_encoder(c):
-    """Check the config.json file of the speaker encoder"""
-    check_argument("run_name", c, restricted=True, val_type=str)
-    check_argument("run_description", c, val_type=str)
-    # audio processing parameters
-    check_argument("audio", c, restricted=True, val_type=dict)
-    check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
-    check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
-    check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
-    check_argument(
-        "frame_length_ms",
-        c["audio"],
-        restricted=True,
-        val_type=float,
-        min_val=10,
-        max_val=1000,
-        alternative="win_length",
-    )
-    check_argument(
-        "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
-    )
-    check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
-    check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
-    check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
-    check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
-    check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)
-    # training parameters
-    check_argument("loss", c, enum_list=["ge2e", "angleproto", "softmaxproto"], restricted=True, val_type=str)
-    check_argument("grad_clip", c, restricted=True, val_type=float)
-    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
-    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
-    check_argument("lr_decay", c, restricted=True, val_type=bool)
-    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
-    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
-    check_argument("num_speakers_in_batch", c, restricted=True, val_type=int)
-    check_argument("num_loader_workers", c, restricted=True, val_type=int)
-    check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
-    # checkpoint and output parameters
-    check_argument("steps_plot_stats", c, restricted=True, val_type=int)
-    check_argument("checkpoint", c, restricted=True, val_type=bool)
-    check_argument("save_step", c, restricted=True, val_type=int)
-    check_argument("print_step", c, restricted=True, val_type=int)
-    check_argument("output_path", c, restricted=True, val_type=str)
-    # model parameters
-    check_argument("model", c, restricted=True, val_type=dict)
-    check_argument("model_name", c, restricted=True, val_type=str)
-    check_argument("input_dim", c["model"], restricted=True, val_type=int)
-    check_argument("proj_dim", c["model"], restricted=True, val_type=int)
-    if c.model_name.lower() == 'lstm':
-        check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
-        check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
-        check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)
-    # in-memory storage parameters
-    check_argument("storage", c, restricted=True, val_type=dict)
-    check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
-    check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100)
-    # datasets - checking only the first entry
-    check_argument("datasets", c, restricted=True, val_type=list)
-    for dataset_entry in c["datasets"]:
-        check_argument("name", dataset_entry, restricted=True, val_type=str)
-        check_argument("path", dataset_entry, restricted=True, val_type=str)
-        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
-        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)
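
Finally, setup_model() now dispatches on the nested model_params["model_name"] rather than a top-level model_name, matching the reworked JSON configs, and the deleted check_config_speaker_encoder() is replaced by the coqpit-based SpeakerEncoderConfig above. A hedged usage sketch (import paths assumed from the hunks above):

    from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig  # assumed path
    from TTS.speaker_encoder.utils.generic_utils import setup_model

    c = SpeakerEncoderConfig()
    c.model_params = {"model_name": "resnet", "input_dim": 80, "proj_dim": 512}  # values from the resnet config above
    model = setup_model(c)       # -> ResNetSpeakerEncoder(input_dim=80, proj_dim=512)
    print(type(model).__name__)  # ResNetSpeakerEncoder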