From bec85ac58d21536e8bbd395eac5f7b70a1618206 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 31 May 2021 16:37:15 +0200
Subject: [PATCH] make style

---
 TTS/bin/compute_embeddings.py                 |  4 +-
 TTS/bin/train_encoder.py                      |  9 +--
 TTS/speaker_encoder/dataset.py                | 49 +++++++------
 TTS/speaker_encoder/losses.py                 | 10 ++-
 TTS/speaker_encoder/models/resnet.py          | 37 ++++++----
 TTS/speaker_encoder/speaker_encoder_config.py |  5 +-
 TTS/speaker_encoder/utils/generic_utils.py    | 69 ++++++++++++-------
 tests/test_speaker_encoder.py                 |  5 ++
 tests/test_speaker_encoder_train.py           |  2 +-
 9 files changed, 115 insertions(+), 75 deletions(-)

diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index 003da1e5..872fc875 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -5,11 +5,11 @@
 import os
 import torch
 from tqdm import tqdm
+from TTS.config import BaseDatasetConfig, load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
-from TTS.config import load_config, BaseDatasetConfig
 
 parser = argparse.ArgumentParser(
     description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
@@ -100,7 +100,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
 
 if speaker_mapping:
     # save speaker_mapping if target dataset is defined
-    if '.json' not in args.output_path:
+    if ".json" not in args.output_path:
         mapping_file_path = os.path.join(args.output_path, "speakers.json")
     else:
         mapping_file_path = args.output_path
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index c9493535..48309dc9 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -10,10 +10,8 @@
 import torch
 from torch.utils.data import DataLoader
 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
-
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
-
 from TTS.speaker_encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.utils.arguments import init_training
@@ -45,7 +43,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         storage_size=c.storage["storage_size"],
         sample_from_storage_p=c.storage["sample_from_storage_p"],
         verbose=verbose,
-        augmentation_config=c.audio_augmentation
+        augmentation_config=c.audio_augmentation,
     )
 
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@@ -170,19 +168,18 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         raise Exception("The %s not is a loss supported" % c.loss)
 
-
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
             model.load_state_dict(checkpoint["model"])
 
-            if 'criterion' in checkpoint:
+            if "criterion" in checkpoint:
                 criterion.load_state_dict(checkpoint["criterion"])
 
         except (KeyError, RuntimeError):
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
+            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
             model.load_state_dict(model_dict)
             del model_dict
         for group in optimizer.param_groups:
diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py
index cd95a4f5..6b2b0dd4 100644
--- a/TTS/speaker_encoder/dataset.py
+++ b/TTS/speaker_encoder/dataset.py
@@ -1,24 +1,25 @@
-
 import random
 
 import numpy as np
 import torch
 from torch.utils.data import Dataset
+
 from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
 
+
 class SpeakerEncoderDataset(Dataset):
     def __init__(
-            self,
-            ap,
-            meta_data,
-            voice_len=1.6,
-            num_speakers_in_batch=64,
-            storage_size=1,
-            sample_from_storage_p=0.5,
-            num_utter_per_speaker=10,
-            skip_speakers=False,
-            verbose=False,
-            augmentation_config=None
+        self,
+        ap,
+        meta_data,
+        voice_len=1.6,
+        num_speakers_in_batch=64,
+        storage_size=1,
+        sample_from_storage_p=0.5,
+        num_utter_per_speaker=10,
+        skip_speakers=False,
+        verbose=False,
+        augmentation_config=None,
     ):
         """
        Args:
@@ -38,23 +39,25 @@ class SpeakerEncoderDataset(Dataset):
         self.verbose = verbose
         self.__parse_items()
         storage_max_size = storage_size * num_speakers_in_batch
-        self.storage = Storage(maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch)
+        self.storage = Storage(
+            maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
+        )
         self.sample_from_storage_p = float(sample_from_storage_p)
 
         speakers_aux = list(self.speakers)
         speakers_aux.sort()
-        self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)}
+        self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
 
         # Augmentation
         self.augmentator = None
         self.gaussian_augmentation_config = None
         if augmentation_config:
-            self.data_augmentation_p = augmentation_config['p']
-            if self.data_augmentation_p and ('additive' in augmentation_config or 'rir' in augmentation_config):
+            self.data_augmentation_p = augmentation_config["p"]
+            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
                 self.augmentator = AugmentWAV(ap, augmentation_config)
 
-            if 'gaussian' in augmentation_config.keys():
-                self.gaussian_augmentation_config = augmentation_config['gaussian']
+            if "gaussian" in augmentation_config.keys():
+                self.gaussian_augmentation_config = augmentation_config["gaussian"]
 
         if self.verbose:
             print("\n > DataLoader initialization")
@@ -231,9 +234,13 @@ class SpeakerEncoderDataset(Dataset):
                     offset = random.randint(0, wav.shape[0] - self.seq_len)
                     wav = wav[offset : offset + self.seq_len]
                     # add random gaussian noise
-                    if self.gaussian_augmentation_config and self.gaussian_augmentation_config['p']:
-                        if random.random() < self.gaussian_augmentation_config['p']:
-                            wav += np.random.normal(self.gaussian_augmentation_config['min_amplitude'], self.gaussian_augmentation_config['max_amplitude'], size=len(wav))
+                    if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
+                        if random.random() < self.gaussian_augmentation_config["p"]:
+                            wav += np.random.normal(
+                                self.gaussian_augmentation_config["min_amplitude"],
+                                self.gaussian_augmentation_config["max_amplitude"],
+                                size=len(wav),
+                            )
                     mel = self.ap.melspectrogram(wav)
                     feats_.append(torch.FloatTensor(mel))
diff --git a/TTS/speaker_encoder/losses.py b/TTS/speaker_encoder/losses.py
index 9b573b6d..ac7e62bf 100644
--- a/TTS/speaker_encoder/losses.py
+++ b/TTS/speaker_encoder/losses.py
@@ -162,6 +162,7 @@ class AngleProtoLoss(nn.Module):
         L = self.criterion(cos_sim_matrix, label)
         return L
 
+
 class SoftmaxLoss(nn.Module):
     """
     Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
@@ -169,13 +170,14 @@ class SoftmaxLoss(nn.Module):
         Args:
             - embedding_dim (float): speaker embedding dim
             - n_speakers (float): number of speakers
     """
+
     def __init__(self, embedding_dim, n_speakers):
         super().__init__()
 
         self.criterion = torch.nn.CrossEntropyLoss()
         self.fc = nn.Linear(embedding_dim, n_speakers)
 
-        print('Initialised Softmax Loss')
+        print("Initialised Softmax Loss")
 
     def forward(self, x, label=None):
         # reshape for compatibility
@@ -187,6 +189,7 @@ class SoftmaxLoss(nn.Module):
 
         return L
 
+
 class SoftmaxAngleProtoLoss(nn.Module):
     """
     Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
@@ -196,13 +199,14 @@ class SoftmaxAngleProtoLoss(nn.Module):
         Args:
             - embedding_dim (float): speaker embedding dim
             - n_speakers (float): number of speakers
             - init_w (float): defines the initial value of w
             - init_b (float): definies the initial value of b
     """
+
     def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
         super().__init__()
 
         self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
         self.angleproto = AngleProtoLoss(init_w, init_b)
 
-        print('Initialised SoftmaxAnglePrototypical Loss')
+        print("Initialised SoftmaxAnglePrototypical Loss")
 
     def forward(self, x, label=None):
         """
@@ -213,4 +217,4 @@ class SoftmaxAngleProtoLoss(nn.Module):
 
         Ls = self.softmax(x, label)
 
-        return Ls+Lp
+        return Ls + Lp
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index aa2171ed..ce86b01f 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -1,7 +1,8 @@
-import torch
 import numpy as np
+import torch
 import torch.nn as nn
 
+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -10,7 +11,7 @@ class SELayer(nn.Module):
             nn.Linear(channel, channel // reduction),
             nn.ReLU(inplace=True),
             nn.Linear(channel // reduction, channel),
-            nn.Sigmoid()
+            nn.Sigmoid(),
         )
 
     def forward(self, x):
@@ -19,6 +20,7 @@ class SELayer(nn.Module):
         y = self.fc(y).view(b, c, 1, 1)
         return x * y
 
+
 class SEBasicBlock(nn.Module):
     expansion = 1
 
@@ -51,12 +53,22 @@ class SEBasicBlock(nn.Module):
         out = self.relu(out)
         return out
 
+
 class ResNetSpeakerEncoder(nn.Module):
     """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
     Adapted from: https://github.com/clovaai/voxceleb_trainer
     """
+
     # pylint: disable=W0102
-    def __init__(self, input_dim=64, proj_dim=512, layers=[3, 4, 6, 3], num_filters=[32, 64, 128, 256], encoder_type='ASP', log_input=False):
+    def __init__(
+        self,
+        input_dim=64,
+        proj_dim=512,
+        layers=[3, 4, 6, 3],
+        num_filters=[32, 64, 128, 256],
+        encoder_type="ASP",
+        log_input=False,
+    ):
         super(ResNetSpeakerEncoder, self).__init__()
 
         self.encoder_type = encoder_type
@@ -74,7 +86,7 @@ class ResNetSpeakerEncoder(nn.Module):
 
         self.instancenorm = nn.InstanceNorm1d(input_dim)
 
-        outmap_size = int(self.input_dim/8)
+        outmap_size = int(self.input_dim / 8)
 
         self.attention = nn.Sequential(
             nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
@@ -82,14 +94,14 @@ class ResNetSpeakerEncoder(nn.Module):
             nn.BatchNorm1d(128),
             nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
             nn.Softmax(dim=2),
-            )
+        )
 
         if self.encoder_type == "SAP":
             out_dim = num_filters[3] * outmap_size
         elif self.encoder_type == "ASP":
             out_dim = num_filters[3] * outmap_size * 2
         else:
-            raise ValueError('Undefined encoder')
+            raise ValueError("Undefined encoder")
 
         self.fc = nn.Linear(out_dim, proj_dim)
 
@@ -98,7 +110,7 @@ class ResNetSpeakerEncoder(nn.Module):
     def _init_layers(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -107,8 +119,7 @@ class ResNetSpeakerEncoder(nn.Module):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
-                nn.Conv2d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
+                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                 nn.BatchNorm2d(planes * block.expansion),
             )
 
@@ -131,7 +142,7 @@ class ResNetSpeakerEncoder(nn.Module):
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 if self.log_input:
-                    x = (x+1e-6).log()
+                    x = (x + 1e-6).log()
                 x = self.instancenorm(x).unsqueeze(1)
 
         x = self.conv1(x)
@@ -151,7 +162,7 @@ class ResNetSpeakerEncoder(nn.Module):
             x = torch.sum(x * w, dim=2)
         elif self.encoder_type == "ASP":
             mu = torch.sum(x * w, dim=2)
-            sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
+            sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
             x = torch.cat((mu, sg), 1)
 
         x = x.view(x.size()[0], -1)
@@ -172,12 +183,12 @@ class ResNetSpeakerEncoder(nn.Module):
         if max_len < num_frames:
             num_frames = max_len
 
-        offsets = np.linspace(0, max_len-num_frames, num=num_eval)
+        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
 
         frames_batch = []
         for offset in offsets:
             offset = int(offset)
-            end_offset = int(offset+num_frames)
+            end_offset = int(offset + num_frames)
             frames = x[:, offset:end_offset]
             frames_batch.append(frames)
diff --git a/TTS/speaker_encoder/speaker_encoder_config.py b/TTS/speaker_encoder/speaker_encoder_config.py
index 31149822..e830a0f5 100644
--- a/TTS/speaker_encoder/speaker_encoder_config.py
+++ b/TTS/speaker_encoder/speaker_encoder_config.py
@@ -25,10 +25,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
         }
     )
 
-    audio_augmentation : dict = field(
-        default_factory=lambda: {
-        }
-    )
+    audio_augmentation: dict = field(default_factory=lambda: {})
 
     storage: dict = field(
         default_factory=lambda: {
diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py
index 3299f75a..fb61e48e 100644
--- a/TTS/speaker_encoder/utils/generic_utils.py
+++ b/TTS/speaker_encoder/utils/generic_utils.py
@@ -1,18 +1,18 @@
-import re
+import datetime
+import glob
 import os
+import random
+import re
+from multiprocessing import Manager
 
 import numpy as np
 import torch
-import glob
-import random
-import datetime
-
 from scipy import signal
-from multiprocessing import Manager
 
 from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
 
+
 class Storage(object):
     def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
         # use multiprocessing for threading safe
@@ -53,19 +53,19 @@ class Storage(object):
         return self.storage[random.randint(0, storage_size)]
 
     def get_random_sample_fast(self):
-        '''Call this method only when storage is full'''
+        """Call this method only when storage is full"""
         return self.storage[random.randint(0, self.safe_storage_size)]
 
-class AugmentWAV(object):
 
+class AugmentWAV(object):
     def __init__(self, ap, augmentation_config):
 
         self.ap = ap
         self.use_additive_noise = False
 
-        if 'additive' in augmentation_config.keys():
-            self.additive_noise_config = augmentation_config['additive']
-            additive_path = self.additive_noise_config['sounds_path']
+        if "additive" in augmentation_config.keys():
+            self.additive_noise_config = augmentation_config["additive"]
+            additive_path = self.additive_noise_config["sounds_path"]
             if additive_path:
                 self.use_additive_noise = True
                 # get noise types
@@ -74,12 +74,12 @@ class AugmentWAV(object):
                     if isinstance(self.additive_noise_config[key], dict):
                         self.additive_noise_types.append(key)
 
-                additive_files = glob.glob(os.path.join(additive_path, '**/*.wav'), recursive=True)
+                additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)
 
                 self.noise_list = {}
 
                 for wav_file in additive_files:
-                    noise_dir = wav_file.replace(additive_path, '').split(os.sep)[0]
+                    noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
                     # ignore not listed directories
                     if noise_dir not in self.additive_noise_types:
                         continue
@@ -87,14 +87,16 @@ class AugmentWAV(object):
                         self.noise_list[noise_dir] = []
                     self.noise_list[noise_dir].append(wav_file)
 
-                print(f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}")
+                print(
+                    f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
+                )
 
         self.use_rir = False
 
-        if 'rir' in augmentation_config.keys():
-            self.rir_config = augmentation_config['rir']
-            if self.rir_config['rir_path']:
-                self.rir_files = glob.glob(os.path.join(self.rir_config['rir_path'], '**/*.wav'), recursive=True)
+        if "rir" in augmentation_config.keys():
+            self.rir_config = augmentation_config["rir"]
+            if self.rir_config["rir_path"]:
+                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                 self.use_rir = True
 
                 print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
@@ -111,9 +113,15 @@ class AugmentWAV(object):
 
     def additive_noise(self, noise_type, audio):
 
-        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
+        clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
 
-        noise_list = random.sample(self.noise_list[noise_type], random.randint(self.additive_noise_config[noise_type]['min_num_noises'], self.additive_noise_config[noise_type]['max_num_noises']))
+        noise_list = random.sample(
+            self.noise_list[noise_type],
+            random.randint(
+                self.additive_noise_config[noise_type]["min_num_noises"],
+                self.additive_noise_config[noise_type]["max_num_noises"],
+            ),
+        )
 
         audio_len = audio.shape[0]
         noises_wav = None
@@ -123,7 +131,10 @@ class AugmentWAV(object):
             if noiseaudio.shape[0] < audio_len:
                 continue
 
-            noise_snr = random.uniform(self.additive_noise_config[noise_type]['min_snr_in_db'], self.additive_noise_config[noise_type]['max_num_noises'])
+            noise_snr = random.uniform(
+                self.additive_noise_config[noise_type]["min_snr_in_db"],
+                self.additive_noise_config[noise_type]["max_num_noises"],
+            )
             noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
             noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
 
@@ -144,7 +155,7 @@ class AugmentWAV(object):
         rir_file = random.choice(self.rir_files)
         rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
         rir = rir / np.sqrt(np.sum(rir ** 2))
-        return signal.convolve(audio, rir, mode=self.rir_config['conv_mode'])[:audio_len]
+        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
 
     def apply_one(self, audio):
         noise_type = random.choice(self.global_noise_list)
@@ -153,17 +164,25 @@ class AugmentWAV(object):
 
         return self.additive_noise(noise_type, audio)
 
+
 def to_camel(text):
     text = text.capitalize()
     return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
 
+
 def setup_model(c):
-    if c.model_params['model_name'].lower() == 'lstm':
-        model = LSTMSpeakerEncoder(c.model_params["input_dim"], c.model_params["proj_dim"], c.model_params["lstm_dim"], c.model_params["num_lstm_layers"])
-    elif c.model_params['model_name'].lower() == 'resnet':
+    if c.model_params["model_name"].lower() == "lstm":
+        model = LSTMSpeakerEncoder(
+            c.model_params["input_dim"],
+            c.model_params["proj_dim"],
+            c.model_params["lstm_dim"],
+            c.model_params["num_lstm_layers"],
+        )
+    elif c.model_params["model_name"].lower() == "resnet":
         model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
     return model
 
+
 def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
     checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
     checkpoint_path = os.path.join(out_path, checkpoint_path)
diff --git a/tests/test_speaker_encoder.py b/tests/test_speaker_encoder.py
index f56a9577..0bb07f37 100644
--- a/tests/test_speaker_encoder.py
+++ b/tests/test_speaker_encoder.py
@@ -6,6 +6,7 @@ from tests import get_tests_input_path
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
+
 file_path = get_tests_input_path()
 
@@ -39,6 +40,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
         assert output.shape[1] == 256
         assert len(output.shape) == 2
 
+
 class ResNetSpeakerEncoderTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -65,6 +67,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
         assert output.shape[1] == 256
         assert len(output.shape) == 2
 
+
 class GE2ELossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -92,6 +95,7 @@ class GE2ELossTests(unittest.TestCase):
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
 
+
 class AngleProtoLossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -121,6 +125,7 @@ class AngleProtoLossTests(unittest.TestCase):
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
 
+
 class SoftmaxAngleProtoLossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
diff --git a/tests/test_speaker_encoder_train.py b/tests/test_speaker_encoder_train.py
index e168a785..21b12074 100644
--- a/tests/test_speaker_encoder_train.py
+++ b/tests/test_speaker_encoder_train.py
@@ -46,7 +46,7 @@
 run_cli(command_train)
 shutil.rmtree(continue_path)
 
 # test resnet speaker encoder
-config.model_params['model_name'] = "resnet"
+config.model_params["model_name"] = "resnet"
 config.save_json(config_path)
 
 # train the model for one epoch
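--
Reviewer note (not part of the commit): SpeakerEncoderDataset and AugmentWAV,
both touched above, are driven by the audio_augmentation dict, which defaults
to {} in SpeakerEncoderConfig. A minimal sketch of the shape that dict can
take, reconstructed only from the key lookups visible in this diff; the paths
and numeric values are invented placeholders, not values from the source:

    audio_augmentation = {
        "p": 0.5,  # overall probability that a sample gets augmented
        "additive": {
            "sounds_path": "/data/noises/",  # hypothetical noise corpus root
            # every dict-valued key under "additive" is treated as a noise type
            "speech": {
                "min_num_noises": 1,   # how many noise clips to mix in
                "max_num_noises": 3,
                "min_snr_in_db": 5,    # lower bound of the sampled SNR
            },
        },
        "rir": {
            "rir_path": "/data/rirs/",  # hypothetical RIR corpus root
            "conv_mode": "full",        # passed to scipy.signal.convolve
        },
        "gaussian": {
            "p": 0.7,               # per-sample probability of gaussian noise
            "min_amplitude": 0.0,   # loc argument of np.random.normal
            "max_amplitude": 1e-5,  # scale argument of np.random.normal
        },
    }

One pre-existing oddity this style-only commit leaves unchanged: in
AugmentWAV.additive_noise, noise_snr is drawn with
random.uniform(config["min_snr_in_db"], config["max_num_noises"]); the second
key presumably ought to be a separate "max_snr_in_db" entry.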