make style

Eren Gölge 2021-05-31 16:37:15 +02:00
parent 975531b3f2
commit bec85ac58d
9 changed files with 115 additions and 75 deletions

View File

@@ -5,11 +5,11 @@ import os
import torch
from tqdm import tqdm
from TTS.config import BaseDatasetConfig, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from TTS.config import load_config, BaseDatasetConfig
parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
@@ -100,7 +100,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
if speaker_mapping:
# save speaker_mapping if target dataset is defined
if '.json' not in args.output_path:
if ".json" not in args.output_path:
mapping_file_path = os.path.join(args.output_path, "speakers.json")
else:
mapping_file_path = args.output_path
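
The ".json" check above decides whether args.output_path is treated as a directory or as an explicit file path. A standalone restatement of that branch, with a hypothetical resolve_mapping_path helper name:

import os

def resolve_mapping_path(output_path):
    # a directory-like path gets the default "speakers.json" file name;
    # an explicit .json path is used as given (mirrors the branch above)
    if ".json" not in output_path:
        return os.path.join(output_path, "speakers.json")
    return output_path

assert resolve_mapping_path("out") == os.path.join("out", "speakers.json")
assert resolve_mapping_path("out/speakers.json") == "out/speakers.json"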

View File

@@ -10,10 +10,8 @@ import torch
from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.arguments import init_training
@@ -45,7 +43,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
verbose=verbose,
augmentation_config=c.audio_augmentation
augmentation_config=c.audio_augmentation,
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@@ -170,19 +168,18 @@ def main(args): # pylint: disable=redefined-outer-name
else:
raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path:
checkpoint = torch.load(args.restore_path)
try:
model.load_state_dict(checkpoint["model"])
if 'criterion' in checkpoint:
if "criterion" in checkpoint:
criterion.load_state_dict(checkpoint["criterion"])
except (KeyError, RuntimeError):
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
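
The try/except restore path above falls back to set_init_dict when the checkpoint no longer matches the model. A simplified sketch of that idea, assuming torch state_dict entries (the real helper also consults the config, so treat this as a reduction):

def partial_restore(model_dict, checkpoint_state):
    # keep only checkpoint tensors whose key exists in the new model
    # and whose shape still matches; everything else stays freshly
    # initialized, which is what "Partial model initialization" means
    matched = {
        k: v
        for k, v in checkpoint_state.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(matched)
    return model_dict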

View File

@@ -1,24 +1,25 @@
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
class SpeakerEncoderDataset(Dataset):
def __init__(
self,
ap,
meta_data,
voice_len=1.6,
num_speakers_in_batch=64,
storage_size=1,
sample_from_storage_p=0.5,
num_utter_per_speaker=10,
skip_speakers=False,
verbose=False,
augmentation_config=None
self,
ap,
meta_data,
voice_len=1.6,
num_speakers_in_batch=64,
storage_size=1,
sample_from_storage_p=0.5,
num_utter_per_speaker=10,
skip_speakers=False,
verbose=False,
augmentation_config=None,
):
"""
Args:
@@ -38,23 +39,25 @@ class SpeakerEncoderDataset(Dataset):
self.verbose = verbose
self.__parse_items()
storage_max_size = storage_size * num_speakers_in_batch
self.storage = Storage(maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch)
self.storage = Storage(
maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
)
self.sample_from_storage_p = float(sample_from_storage_p)
speakers_aux = list(self.speakers)
speakers_aux.sort()
self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)}
self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
# Augmentation
self.augmentator = None
self.gaussian_augmentation_config = None
if augmentation_config:
self.data_augmentation_p = augmentation_config['p']
if self.data_augmentation_p and ('additive' in augmentation_config or 'rir' in augmentation_config):
self.data_augmentation_p = augmentation_config["p"]
if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
self.augmentator = AugmentWAV(ap, augmentation_config)
if 'gaussian' in augmentation_config.keys():
self.gaussian_augmentation_config = augmentation_config['gaussian']
if "gaussian" in augmentation_config.keys():
self.gaussian_augmentation_config = augmentation_config["gaussian"]
if self.verbose:
print("\n > DataLoader initialization")
@@ -231,9 +234,13 @@ class SpeakerEncoderDataset(Dataset):
offset = random.randint(0, wav.shape[0] - self.seq_len)
wav = wav[offset : offset + self.seq_len]
# add random gaussian noise
if self.gaussian_augmentation_config and self.gaussian_augmentation_config['p']:
if random.random() < self.gaussian_augmentation_config['p']:
wav += np.random.normal(self.gaussian_augmentation_config['min_amplitude'], self.gaussian_augmentation_config['max_amplitude'], size=len(wav))
if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
if random.random() < self.gaussian_augmentation_config["p"]:
wav += np.random.normal(
self.gaussian_augmentation_config["min_amplitude"],
self.gaussian_augmentation_config["max_amplitude"],
size=len(wav),
)
mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel))
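
Pieced together from the constructor signature above and the setup_loader call earlier, a hypothetical instantiation could look as follows; the config attribute names (num_speakers_in_batch, the c.audio dict) are assumptions, and c stands for a loaded SpeakerEncoderConfig:

from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(**c.audio)
meta_data_train, _ = load_meta_data(c.datasets)
dataset = SpeakerEncoderDataset(
    ap,
    meta_data_train,
    voice_len=1.6,
    num_speakers_in_batch=c.num_speakers_in_batch,  # assumed config field
    storage_size=c.storage["storage_size"],
    sample_from_storage_p=c.storage["sample_from_storage_p"],
    num_utter_per_speaker=10,
    skip_speakers=False,
    verbose=True,
    augmentation_config=c.audio_augmentation,
)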

View File

@@ -162,6 +162,7 @@ class AngleProtoLoss(nn.Module):
L = self.criterion(cos_sim_matrix, label)
return L
class SoftmaxLoss(nn.Module):
"""
Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
@@ -169,13 +170,14 @@ class SoftmaxLoss(nn.Module):
- embedding_dim (float): speaker embedding dim
- n_speakers (float): number of speakers
"""
def __init__(self, embedding_dim, n_speakers):
super().__init__()
self.criterion = torch.nn.CrossEntropyLoss()
self.fc = nn.Linear(embedding_dim, n_speakers)
print('Initialised Softmax Loss')
print("Initialised Softmax Loss")
def forward(self, x, label=None):
# reshape for compatibility
@@ -187,6 +189,7 @@ class SoftmaxLoss(nn.Module):
return L
class SoftmaxAngleProtoLoss(nn.Module):
"""
Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
@@ -196,13 +199,14 @@ class SoftmaxAngleProtoLoss(nn.Module):
- init_w (float): defines the initial value of w
- init_b (float): defines the initial value of b
"""
def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
super().__init__()
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
self.angleproto = AngleProtoLoss(init_w, init_b)
print('Initialised SoftmaxAnglePrototypical Loss')
print("Initialised SoftmaxAnglePrototypical Loss")
def forward(self, x, label=None):
"""
@@ -213,4 +217,4 @@ class SoftmaxAngleProtoLoss(nn.Module):
Ls = self.softmax(x, label)
return Ls+Lp
return Ls + Lp
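
A quick smoke test for the combined loss, following the speakers x utterances x embedding-dim convention the unit tests below use (the exact shapes here are an assumption):

import torch
from TTS.speaker_encoder.losses import SoftmaxAngleProtoLoss

embeddings = torch.rand(4, 5, 64)  # 4 speakers, 5 utterances, 64-dim
labels = torch.randint(0, 10, (4, 5))  # class ids in [0, n_speakers)
criterion = SoftmaxAngleProtoLoss(embedding_dim=64, n_speakers=10)
loss = criterion(embeddings, labels)  # Ls + Lp, a scalar tensor
print(loss.item())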

View File

@@ -1,7 +1,8 @@
import torch
import numpy as np
import torch
import torch.nn as nn
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@@ -10,7 +11,7 @@ class SELayer(nn.Module):
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel),
nn.Sigmoid()
nn.Sigmoid(),
)
def forward(self, x):
@@ -19,6 +20,7 @@ class SELayer(nn.Module):
y = self.fc(y).view(b, c, 1, 1)
return x * y
class SEBasicBlock(nn.Module):
expansion = 1
@@ -51,12 +53,22 @@ class SEBasicBlock(nn.Module):
out = self.relu(out)
return out
class ResNetSpeakerEncoder(nn.Module):
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
Adapted from: https://github.com/clovaai/voxceleb_trainer
"""
# pylint: disable=W0102
def __init__(self, input_dim=64, proj_dim=512, layers=[3, 4, 6, 3], num_filters=[32, 64, 128, 256], encoder_type='ASP', log_input=False):
def __init__(
self,
input_dim=64,
proj_dim=512,
layers=[3, 4, 6, 3],
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
@@ -74,7 +86,7 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
outmap_size = int(self.input_dim/8)
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
@@ -82,14 +94,14 @@ class ResNetSpeakerEncoder(nn.Module):
nn.BatchNorm1d(128),
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(dim=2),
)
)
if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2
else:
raise ValueError('Undefined encoder')
raise ValueError("Undefined encoder")
self.fc = nn.Linear(out_dim, proj_dim)
@@ -98,7 +110,7 @@ class ResNetSpeakerEncoder(nn.Module):
def _init_layers(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@@ -107,8 +119,7 @@ class ResNetSpeakerEncoder(nn.Module):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
@@ -131,7 +142,7 @@ class ResNetSpeakerEncoder(nn.Module):
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if self.log_input:
x = (x+1e-6).log()
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
x = self.conv1(x)
@@ -151,7 +162,7 @@ class ResNetSpeakerEncoder(nn.Module):
x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)
@@ -172,12 +183,12 @@ class ResNetSpeakerEncoder(nn.Module):
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset+num_frames)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
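
Given the nn.InstanceNorm1d(input_dim) call above, the encoder evidently expects (batch, input_dim, time) spectrogram input. A minimal smoke test under that assumption:

import torch
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder

model = ResNetSpeakerEncoder(input_dim=64, proj_dim=256)
dummy_mels = torch.rand(4, 64, 250)  # batch x mel bands x frames
with torch.no_grad():
    embeddings = model(dummy_mels)
print(embeddings.shape)  # expected: torch.Size([4, 256])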

View File

@@ -25,10 +25,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
}
)
audio_augmentation : dict = field(
default_factory=lambda: {
}
)
audio_augmentation: dict = field(default_factory=lambda: {})
storage: dict = field(
default_factory=lambda: {
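
The diff never shows a populated audio_augmentation dict, but the keys read by the dataset above and by AugmentWAV below suggest a shape like the following; the noise-type name, paths, and values are hypothetical placeholders, and "max_snr_in_db" is an assumed counterpart of the "min_snr_in_db" key that does appear:

audio_augmentation = {
    "p": 0.5,  # probability of augmenting a training sample
    "additive": {
        "sounds_path": "/data/musan",  # hypothetical noise corpus path
        "speech": {  # any dict-valued key is treated as a noise type
            "min_num_noises": 1,
            "max_num_noises": 3,
            "min_snr_in_db": 13,
            "max_snr_in_db": 20,
        },
    },
    "rir": {
        "rir_path": "/data/rirs",  # hypothetical RIR corpus path
        "conv_mode": "full",  # handed to scipy.signal.convolve
    },
    "gaussian": {
        "p": 0.3,
        "min_amplitude": 0.0,
        "max_amplitude": 1e-5,
    },
}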

View File

@@ -1,18 +1,18 @@
import re
import datetime
import glob
import os
import random
import re
from multiprocessing import Manager
import numpy as np
import torch
import glob
import random
import datetime
from scipy import signal
from multiprocessing import Manager
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
class Storage(object):
def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
# use multiprocessing for thread safety
@@ -53,19 +53,19 @@ class Storage(object):
return self.storage[random.randint(0, storage_size)]
def get_random_sample_fast(self):
'''Call this method only when storage is full'''
"""Call this method only when storage is full"""
return self.storage[random.randint(0, self.safe_storage_size)]
class AugmentWAV(object):
class AugmentWAV(object):
def __init__(self, ap, augmentation_config):
self.ap = ap
self.use_additive_noise = False
if 'additive' in augmentation_config.keys():
self.additive_noise_config = augmentation_config['additive']
additive_path = self.additive_noise_config['sounds_path']
if "additive" in augmentation_config.keys():
self.additive_noise_config = augmentation_config["additive"]
additive_path = self.additive_noise_config["sounds_path"]
if additive_path:
self.use_additive_noise = True
# get noise types
@@ -74,12 +74,12 @@ class AugmentWAV(object):
if isinstance(self.additive_noise_config[key], dict):
self.additive_noise_types.append(key)
additive_files = glob.glob(os.path.join(additive_path, '**/*.wav'), recursive=True)
additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)
self.noise_list = {}
for wav_file in additive_files:
noise_dir = wav_file.replace(additive_path, '').split(os.sep)[0]
noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
# ignore not listed directories
if noise_dir not in self.additive_noise_types:
continue
@@ -87,14 +87,16 @@ class AugmentWAV(object):
self.noise_list[noise_dir] = []
self.noise_list[noise_dir].append(wav_file)
print(f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}")
print(
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
)
self.use_rir = False
if 'rir' in augmentation_config.keys():
self.rir_config = augmentation_config['rir']
if self.rir_config['rir_path']:
self.rir_files = glob.glob(os.path.join(self.rir_config['rir_path'], '**/*.wav'), recursive=True)
if "rir" in augmentation_config.keys():
self.rir_config = augmentation_config["rir"]
if self.rir_config["rir_path"]:
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
self.use_rir = True
print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
@@ -111,9 +113,15 @@ class AugmentWAV(object):
def additive_noise(self, noise_type, audio):
clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
noise_list = random.sample(self.noise_list[noise_type], random.randint(self.additive_noise_config[noise_type]['min_num_noises'], self.additive_noise_config[noise_type]['max_num_noises']))
noise_list = random.sample(
self.noise_list[noise_type],
random.randint(
self.additive_noise_config[noise_type]["min_num_noises"],
self.additive_noise_config[noise_type]["max_num_noises"],
),
)
audio_len = audio.shape[0]
noises_wav = None
@@ -123,7 +131,10 @@ class AugmentWAV(object):
if noiseaudio.shape[0] < audio_len:
continue
noise_snr = random.uniform(self.additive_noise_config[noise_type]['min_snr_in_db'], self.additive_noise_config[noise_type]['max_num_noises'])
noise_snr = random.uniform(
self.additive_noise_config[noise_type]["min_snr_in_db"],
self.additive_noise_config[noise_type]["max_num_noises"],
)
noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
@@ -144,7 +155,7 @@ class AugmentWAV(object):
rir_file = random.choice(self.rir_files)
rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
rir = rir / np.sqrt(np.sum(rir ** 2))
return signal.convolve(audio, rir, mode=self.rir_config['conv_mode'])[:audio_len]
return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
def apply_one(self, audio):
noise_type = random.choice(self.global_noise_list)
@@ -153,17 +164,25 @@ class AugmentWAV(object):
return self.additive_noise(noise_type, audio)
def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(c):
if c.model_params['model_name'].lower() == 'lstm':
model = LSTMSpeakerEncoder(c.model_params["input_dim"], c.model_params["proj_dim"], c.model_params["lstm_dim"], c.model_params["num_lstm_layers"])
elif c.model_params['model_name'].lower() == 'resnet':
if c.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
c.model_params["input_dim"],
c.model_params["proj_dim"],
c.model_params["lstm_dim"],
c.model_params["num_lstm_layers"],
)
elif c.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
return model
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
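
One thing the reformat makes easier to spot: the random.uniform upper bound in additive_noise reads "max_num_noises" where "max_snr_in_db" was presumably intended (an observation only; this commit merely rewraps the line). The decibel arithmetic itself reduces to this sketch:

import numpy as np

def mix_at_snr(clean, noise, snr_db):
    # scale `noise` so the clean-to-noise power ratio equals snr_db,
    # mirroring the clean_db/noise_db bookkeeping in additive_noise
    clean_db = 10 * np.log10(np.mean(clean ** 2) + 1e-4)
    noise_db = 10 * np.log10(np.mean(noise ** 2) + 1e-4)
    gain = np.sqrt(10 ** ((clean_db - noise_db - snr_db) / 10))
    return clean + gain * noise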

View File

@@ -6,6 +6,7 @@ from tests import get_tests_input_path
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
file_path = get_tests_input_path()
@@ -39,6 +40,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
assert output.shape[1] == 256
assert len(output.shape) == 2
class ResNetSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
@@ -65,6 +67,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
assert output.shape[1] == 256
assert len(output.shape) == 2
class GE2ELossTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
@@ -92,6 +95,7 @@ class GE2ELossTests(unittest.TestCase):
output = loss.forward(dummy_input)
assert output.item() < 0.005
class AngleProtoLossTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):
@@ -121,6 +125,7 @@ class AngleProtoLossTests(unittest.TestCase):
output = loss.forward(dummy_input)
assert output.item() < 0.005
class SoftmaxAngleProtoLossTests(unittest.TestCase):
# pylint: disable=R0201
def test_in_out(self):

View File

@@ -46,7 +46,7 @@ run_cli(command_train)
shutil.rmtree(continue_path)
# test resnet speaker encoder
config.model_params['model_name'] = "resnet"
config.model_params["model_name"] = "resnet"
config.save_json(config_path)
# train the model for one epoch