coqui-tts/TTS/speaker_encoder/utils/generic_utils.py

import datetime
import os
import re
import numpy as np
import torch
import glob
import random
from scipy import signal
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.utils.generic_utils import check_argument
class AugmentWAV(object):

    def __init__(self, ap, augmentation_config):
        self.ap = ap
        '''augmentation_config = {
            "p": 1,
            "rir": {
                "rir_path": "rir_path/",
                "conv_mode": "full"
            },
            "additive": {
                "sounds_path": "musan/",
                # directories in sounds_path
                "speech": {
                    "min_snr_in_db": 13,
                    "max_snr_in_db": 20,
                    "min_num_noises": 3,
                    "max_num_noises": 7
                },
                "noise": {
                    "min_snr_in_db": 0,
                    "max_snr_in_db": 15,
                    "min_num_noises": 1,
                    "max_num_noises": 1
                },
                "music": {
                    "min_snr_in_db": 5,
                    "max_snr_in_db": 15,
                    "min_num_noises": 1,
                    "max_num_noises": 1
                }
            }
        }'''
        self.use_additive_noise = False
        if 'additive' in augmentation_config.keys():
            self.additive_noise_config = augmentation_config['additive']
            additive_path = self.additive_noise_config['sounds_path']
            if additive_path:
                self.use_additive_noise = True
                # get noise types
                self.additive_noise_types = []
                for key in self.additive_noise_config.keys():
                    if isinstance(self.additive_noise_config[key], dict):
                        self.additive_noise_types.append(key)

                additive_files = glob.glob(os.path.join(additive_path, '**/*.wav'), recursive=True)

                self.noise_list = {}
                for wav_file in additive_files:
                    noise_dir = wav_file.replace(additive_path, '').split(os.sep)[0]
                    # ignore directories that are not listed in the config
                    if noise_dir not in self.additive_noise_types:
                        continue
                    if noise_dir not in self.noise_list:
                        self.noise_list[noise_dir] = []
                    self.noise_list[noise_dir].append(wav_file)

                print(f" | > Using Additive Noise Augmentation: with {len(additive_files)} audio instances from {self.additive_noise_types}")
        self.use_rir = False
        if 'rir' in augmentation_config.keys():
            self.rir_config = augmentation_config['rir']
            if self.rir_config['rir_path']:
                self.rir_files = glob.glob(os.path.join(self.rir_config['rir_path'], '**/*.wav'), recursive=True)
                self.use_rir = True

                print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audio instances")

        self.create_augmentation_global_list()

    def create_augmentation_global_list(self):
        if self.use_additive_noise:
            self.global_noise_list = self.additive_noise_types
        else:
            self.global_noise_list = []
        if self.use_rir:
            self.global_noise_list.append("RIR_AUG")

    def additive_noise(self, noise_type, audio):
        clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)

        noise_list = random.sample(
            self.noise_list[noise_type],
            random.randint(
                self.additive_noise_config[noise_type]['min_num_noises'],
                self.additive_noise_config[noise_type]['max_num_noises'],
            ),
        )

        audio_len = audio.shape[0]
        noises_wav = None
        for noise in noise_list:
            noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]

            if noiseaudio.shape[0] < audio_len:
                continue

            noise_snr = random.uniform(
                self.additive_noise_config[noise_type]['min_snr_in_db'],
                self.additive_noise_config[noise_type]['max_snr_in_db'],
            )
            noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
            # scale the noise so the clean-to-noise power ratio matches the sampled SNR
            noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

            if noises_wav is None:
                noises_wav = noise_wav
            else:
                noises_wav += noise_wav

        # if all sampled noise files were shorter than the audio, sample other files
        if noises_wav is None:
            print(" | > All sampled noise files were shorter than the audio, resampling")
            return self.additive_noise(noise_type, audio)

        return audio + noises_wav

    def reverberate(self, audio):
        audio_len = audio.shape[0]

        rir_file = random.choice(self.rir_files)
        rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
        rir = rir / np.sqrt(np.sum(rir ** 2))
        return signal.convolve(audio, rir, mode=self.rir_config['conv_mode'])[:audio_len]

    def apply_one(self, audio):
        noise_type = random.choice(self.global_noise_list)
        if noise_type == "RIR_AUG":
            return self.reverberate(audio)
        return self.additive_noise(noise_type, audio)
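
# Example usage (a sketch, not part of the original module): `ap` is assumed to be a
# TTS AudioProcessor, and the augmentation config follows the structure documented in
# `AugmentWAV.__init__`; the config attribute name `audio_augmentation` and the paths
# are placeholders.
#
#     from TTS.utils.audio import AudioProcessor
#
#     ap = AudioProcessor(**config.audio)
#     augmenter = AugmentWAV(ap, config.audio_augmentation)
#     wav = ap.load_wav(wav_path, sr=ap.sample_rate)
#     if random.random() < config.audio_augmentation["p"]:
#         wav = augmenter.apply_one(wav)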

def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
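
# `to_camel` maps snake_case names to CamelCase, e.g. to_camel("speaker_encoder") -> "SpeakerEncoder".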

def setup_model(c):
    model = SpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"])
    return model
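
# With a typical config this amounts to something like the call below (illustrative
# values, not defaults taken from this file):
#     SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)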

def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch):
    checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

    new_state_dict = model.state_dict()
    state = {
        "model": new_state_dict,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    torch.save(state, checkpoint_path)
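
# A checkpoint written this way can later be restored along these lines (a sketch):
#     state = torch.load(checkpoint_path, map_location="cpu")
#     model.load_state_dict(state["model"])
#     if state["optimizer"] is not None:
#         optimizer.load_state_dict(state["optimizer"])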

def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            "model": new_state_dict,
            "optimizer": optimizer.state_dict(),
            "step": current_step,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = "best_model.pth.tar"
        bestmodel_path = os.path.join(out_path, bestmodel_path)
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss
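
# Typical call pattern inside a training loop (a sketch; variable names are illustrative):
#     save_checkpoint(model, optimizer, avg_loss, out_path, global_step, epoch)
#     best_loss = save_best_model(model, optimizer, avg_loss, best_loss, out_path, global_step)
# `save_best_model` only writes `best_model.pth.tar` and updates `best_loss` when
# `avg_loss` improves on the previous best.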

def check_config_speaker_encoder(c):
    """Check the config.json file of the speaker encoder"""
    check_argument("run_name", c, restricted=True, val_type=str)
    check_argument("run_description", c, val_type=str)

    # audio processing parameters
    check_argument("audio", c, restricted=True, val_type=dict)
    check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument(
        "frame_length_ms",
        c["audio"],
        restricted=True,
        val_type=float,
        min_val=10,
        max_val=1000,
        alternative="win_length",
    )
    check_argument(
        "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
    )
    check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

    # training parameters
    check_argument("loss", c, enum_list=["ge2e", "angleproto", "softmaxproto"], restricted=True, val_type=str)
    check_argument("grad_clip", c, restricted=True, val_type=float)
    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
    check_argument("lr_decay", c, restricted=True, val_type=bool)
    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
    check_argument("num_speakers_in_batch", c, restricted=True, val_type=int)
    check_argument("num_loader_workers", c, restricted=True, val_type=int)
    check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # checkpoint and output parameters
    check_argument("steps_plot_stats", c, restricted=True, val_type=int)
    check_argument("checkpoint", c, restricted=True, val_type=bool)
    check_argument("save_step", c, restricted=True, val_type=int)
    check_argument("print_step", c, restricted=True, val_type=int)
    check_argument("output_path", c, restricted=True, val_type=str)

    # model parameters
    check_argument("model", c, restricted=True, val_type=dict)
    check_argument("input_dim", c["model"], restricted=True, val_type=int)
    check_argument("proj_dim", c["model"], restricted=True, val_type=int)
    check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
    check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
    check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)

    # in-memory storage parameters
    check_argument("storage", c, restricted=True, val_type=dict)
    check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100)

    # datasets - check each dataset entry
    check_argument("datasets", c, restricted=True, val_type=list)
    for dataset_entry in c["datasets"]:
        check_argument("name", dataset_entry, restricted=True, val_type=str)
        check_argument("path", dataset_entry, restricted=True, val_type=str)
        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)
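
# For reference, a config passing this checker roughly follows the shape below
# (a minimal sketch; values are illustrative, not defaults shipped with the repo):
#     {
#       "run_name": "speaker_encoder", "run_description": "",
#       "audio": {"num_mels": 80, "fft_size": 1024, "sample_rate": 16000, ...},
#       "loss": "angleproto", "grad_clip": 3.0, "epochs": 1000, "lr": 1e-4,
#       "lr_decay": false, "warmup_steps": 4000, "tb_model_param_stats": false,
#       "num_speakers_in_batch": 64, "num_loader_workers": 8, "wd": 1e-6,
#       "steps_plot_stats": 10, "checkpoint": true, "save_step": 1000,
#       "print_step": 20, "output_path": "output/",
#       "model": {"input_dim": 80, "proj_dim": 256, "lstm_dim": 768,
#                 "num_lstm_layers": 3, "use_lstm_with_projection": true},
#       "storage": {"sample_from_storage_p": 0.5, "storage_size": 15},
#       "datasets": [{"name": "voxceleb2", "path": "...", "meta_file_train": "",
#                     "meta_file_val": ""}]
#     }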