import datetime
import glob
import os
import random
import re
from multiprocessing import Manager

import numpy as np
import torch
from scipy import signal

from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.generic_utils import check_argument


class Storage(object):
    def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
        # use a multiprocessing Manager list for thread-safe access
        self.storage = Manager().list()
        self.maxsize = maxsize
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_threads = num_threads
        self.ignore_last_batch = False

        if storage_batchs >= 3:
            self.ignore_last_batch = True

        # used for fast random sampling
        self.safe_storage_size = self.maxsize - self.num_threads
        if self.ignore_last_batch:
            self.safe_storage_size -= self.num_speakers_in_batch

    def __len__(self):
        return len(self.storage)

    def full(self):
        return len(self.storage) >= self.maxsize

    def append(self, item):
        # if the storage is full, remove the oldest item
        if self.full():
            self.storage.pop(0)

        self.storage.append(item)

    def get_random_sample(self):
        # sample from a reduced range, assuming every thread may remove one item
        # from the storage at the same time
        storage_size = len(self.storage) - self.num_threads

        if self.ignore_last_batch:
            storage_size -= self.num_speakers_in_batch

        return self.storage[random.randint(0, storage_size)]

    def get_random_sample_fast(self):
        """Call this method only when the storage is full."""
        return self.storage[random.randint(0, self.safe_storage_size)]
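
# A minimal usage sketch (hypothetical sizes): `Storage` acts as a shared
# ring buffer that caches speaker batches across data-loader workers:
#
#   storage = Storage(maxsize=160, storage_batchs=5, num_speakers_in_batch=32)
#   storage.append(batch_item)
#   if storage.full():
#       item = storage.get_random_sample_fast()
#   else:
#       item = storage.get_random_sample()
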
class AugmentWAV(object):
    def __init__(self, ap, augmentation_config):

        self.ap = ap

        '''augmentation_config = {
            "p": 1,
            "rir": {
                "rir_path": "rir_path/",
                "conv_mode": "full"
            },
            "additive": {
                "sounds_path": "musan/",
                # directories in sounds_path
                "speech": {
                    "min_snr_in_db": 13,
                    "max_snr_in_db": 20,
                    "min_num_noises": 3,
                    "max_num_noises": 7
                },
                "noise": {
                    "min_snr_in_db": 0,
                    "max_snr_in_db": 15,
                    "min_num_noises": 1,
                    "max_num_noises": 1
                },
                "music": {
                    "min_snr_in_db": 5,
                    "max_snr_in_db": 15,
                    "min_num_noises": 1,
                    "max_num_noises": 1
                }
            }
        }'''

        self.use_additive_noise = False
        if "additive" in augmentation_config:
            self.additive_noise_config = augmentation_config["additive"]
            additive_path = self.additive_noise_config["sounds_path"]
            if additive_path:
                self.use_additive_noise = True
                # get noise types
                self.additive_noise_types = []
                for key in self.additive_noise_config.keys():
                    if isinstance(self.additive_noise_config[key], dict):
                        self.additive_noise_types.append(key)

                additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)

                self.noise_list = {}

                for wav_file in additive_files:
                    noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
                    # ignore directories that are not listed in the config
                    if noise_dir not in self.additive_noise_types:
                        continue
                    if noise_dir not in self.noise_list:
                        self.noise_list[noise_dir] = []
                    self.noise_list[noise_dir].append(wav_file)

                print(
                    f" | > Using Additive Noise Augmentation: with {len(additive_files)} audio instances from {self.additive_noise_types}"
                )

        self.use_rir = False
        if "rir" in augmentation_config:
            self.rir_config = augmentation_config["rir"]
            if self.rir_config["rir_path"]:
                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                self.use_rir = True

                print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audio instances")

        self.create_augmentation_global_list()

    def create_augmentation_global_list(self):
        if self.use_additive_noise:
            # copy so that appending "RIR_AUG" does not mutate additive_noise_types
            self.global_noise_list = self.additive_noise_types.copy()
        else:
            self.global_noise_list = []
        if self.use_rir:
            self.global_noise_list.append("RIR_AUG")

    def additive_noise(self, noise_type, audio):
        clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)

        noise_list = random.sample(
            self.noise_list[noise_type],
            random.randint(
                self.additive_noise_config[noise_type]["min_num_noises"],
                self.additive_noise_config[noise_type]["max_num_noises"],
            ),
        )

        audio_len = audio.shape[0]
        noises_wav = None
        for noise in noise_list:
            noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]

            if noiseaudio.shape[0] < audio_len:
                continue

            noise_snr = random.uniform(
                self.additive_noise_config[noise_type]["min_snr_in_db"],
                self.additive_noise_config[noise_type]["max_snr_in_db"],
            )
            noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
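            # scale the noise so that clean_db - scaled_noise_db == noise_snr;
            # scaling the amplitude by g raises the power level by 20*log10(g) dB,
            # so g = 10 ** ((clean_db - noise_db - noise_snr) / 20)
            #       = sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10))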
            noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

            if noises_wav is None:
                noises_wav = noise_wav
            else:
                noises_wav += noise_wav

        # if all sampled noise files were shorter than the audio, sample again
        if noises_wav is None:
            print("audio ignored")
            return self.additive_noise(noise_type, audio)

        return audio + noises_wav

    def reverberate(self, audio):
        audio_len = audio.shape[0]

        rir_file = random.choice(self.rir_files)

        rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
        # normalize the impulse response to unit energy before convolving
        rir = rir / np.sqrt(np.sum(rir ** 2))
        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]

    def apply_one(self, audio):
        noise_type = random.choice(self.global_noise_list)
        if noise_type == "RIR_AUG":
            return self.reverberate(audio)

        return self.additive_noise(noise_type, audio)

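# A minimal usage sketch, assuming `ap` is an audio processor exposing
# `load_wav` and `sample_rate` (e.g. TTS.utils.audio.AudioProcessor) and the
# config follows the structure documented in AugmentWAV.__init__ (the
# "audio_augmentation" key name here is an assumption):
#
#   aug_config = config["audio_augmentation"]
#   augmenter = AugmentWAV(ap, aug_config)
#   if random.random() < aug_config["p"]:
#       wav = augmenter.apply_one(wav)
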
def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
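
# e.g. to_camel("speaker_encoder") returns "SpeakerEncoder".
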
def setup_model(c):
    if c.model_name.lower() == "lstm":
        model = LSTMSpeakerEncoder(
            c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"]
        )
    elif c.model_name.lower() == "resnet":
        model = ResNetSpeakerEncoder(input_dim=c.model["input_dim"], proj_dim=c.model["proj_dim"])
    else:
        raise ValueError(f" [!] Model '{c.model_name}' is not defined.")
    return model
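
# A minimal usage sketch (hypothetical dimensions; `c` is the training config
# loaded elsewhere):
#
#   c.model_name = "lstm"
#   c.model = {"input_dim": 80, "proj_dim": 256, "lstm_dim": 768, "num_lstm_layers": 3}
#   model = setup_model(c)
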
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
    checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

    new_state_dict = model.state_dict()
    state = {
        "model": new_state_dict,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "criterion": criterion.state_dict(),
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    torch.save(state, checkpoint_path)
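
# To restore a checkpoint saved above (sketch; `checkpoint_path` as used in
# save_checkpoint):
#
#   state = torch.load(checkpoint_path, map_location="cpu")
#   model.load_state_dict(state["model"])
#   if state["optimizer"] is not None:
#       optimizer.load_state_dict(state["optimizer"])
#   criterion.load_state_dict(state["criterion"])
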
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            "model": new_state_dict,
            "optimizer": optimizer.state_dict(),
            "criterion": criterion.state_dict(),
            "step": current_step,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = "best_model.pth.tar"
        bestmodel_path = os.path.join(out_path, bestmodel_path)
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss
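
# Typical call pattern inside a training loop (sketch; names are illustrative):
#
#   best_loss = float("inf")
#   for epoch in range(c.epochs):
#       ...  # train and evaluate
#       best_loss = save_best_model(model, optimizer, criterion, eval_loss,
#                                   best_loss, out_path, global_step)
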
def check_config_speaker_encoder(c):
    """Check the config.json file of the speaker encoder"""
    check_argument("run_name", c, restricted=True, val_type=str)
    check_argument("run_description", c, val_type=str)

    # audio processing parameters
    check_argument("audio", c, restricted=True, val_type=dict)
    check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument(
        "frame_length_ms",
        c["audio"],
        restricted=True,
        val_type=float,
        min_val=10,
        max_val=1000,
        alternative="win_length",
    )
    check_argument(
        "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
    )
    check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

    # training parameters
    check_argument("loss", c, enum_list=["ge2e", "angleproto", "softmaxproto"], restricted=True, val_type=str)
    check_argument("grad_clip", c, restricted=True, val_type=float)
    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
    check_argument("lr_decay", c, restricted=True, val_type=bool)
    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
    check_argument("num_speakers_in_batch", c, restricted=True, val_type=int)
    check_argument("num_loader_workers", c, restricted=True, val_type=int)
    check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # checkpoint and output parameters
    check_argument("steps_plot_stats", c, restricted=True, val_type=int)
    check_argument("checkpoint", c, restricted=True, val_type=bool)
    check_argument("save_step", c, restricted=True, val_type=int)
    check_argument("print_step", c, restricted=True, val_type=int)
    check_argument("output_path", c, restricted=True, val_type=str)

    # model parameters
    check_argument("model", c, restricted=True, val_type=dict)
    check_argument("model_name", c, restricted=True, val_type=str)
    check_argument("input_dim", c["model"], restricted=True, val_type=int)
    if c.model_name.lower() == "lstm":
        check_argument("proj_dim", c["model"], restricted=True, val_type=int)
        check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
        check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
        check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)

    # in-memory storage parameters
    check_argument("storage", c, restricted=True, val_type=dict)
    check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100)

    # datasets - checking all entries
    check_argument("datasets", c, restricted=True, val_type=list)
    for dataset_entry in c["datasets"]:
        check_argument("name", dataset_entry, restricted=True, val_type=str)
        check_argument("path", dataset_entry, restricted=True, val_type=str)
        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)