make style

This commit is contained in:
Eren Gölge 2021-05-31 16:37:15 +02:00
parent 975531b3f2
commit bec85ac58d
9 changed files with 115 additions and 75 deletions

View File

@ -5,11 +5,11 @@ import os
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from TTS.config import BaseDatasetConfig, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.config import load_config, BaseDatasetConfig
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.' description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
@ -100,7 +100,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
if speaker_mapping: if speaker_mapping:
# save speaker_mapping if target dataset is defined # save speaker_mapping if target dataset is defined
if '.json' not in args.output_path: if ".json" not in args.output_path:
mapping_file_path = os.path.join(args.output_path, "speakers.json") mapping_file_path = os.path.join(args.output_path, "speakers.json")
else: else:
mapping_file_path = args.output_path mapping_file_path = args.output_path

View File

@ -10,10 +10,8 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.arguments import init_training from TTS.utils.arguments import init_training
@ -45,7 +43,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
storage_size=c.storage["storage_size"], storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"], sample_from_storage_p=c.storage["sample_from_storage_p"],
verbose=verbose, verbose=verbose,
augmentation_config=c.audio_augmentation augmentation_config=c.audio_augmentation,
) )
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@ -170,19 +168,18 @@ def main(args): # pylint: disable=redefined-outer-name
else: else:
raise Exception("The %s not is a loss supported" % c.loss) raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = torch.load(args.restore_path)
try: try:
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
if 'criterion' in checkpoint: if "criterion" in checkpoint:
criterion.load_state_dict(checkpoint["criterion"]) criterion.load_state_dict(checkpoint["criterion"])
except (KeyError, RuntimeError): except (KeyError, RuntimeError):
print(" > Partial model initialization.") print(" > Partial model initialization.")
model_dict = model.state_dict() model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c) model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict) model.load_state_dict(model_dict)
del model_dict del model_dict
for group in optimizer.param_groups: for group in optimizer.param_groups:

View File

@ -1,24 +1,25 @@
import random import random
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
class SpeakerEncoderDataset(Dataset): class SpeakerEncoderDataset(Dataset):
def __init__( def __init__(
self, self,
ap, ap,
meta_data, meta_data,
voice_len=1.6, voice_len=1.6,
num_speakers_in_batch=64, num_speakers_in_batch=64,
storage_size=1, storage_size=1,
sample_from_storage_p=0.5, sample_from_storage_p=0.5,
num_utter_per_speaker=10, num_utter_per_speaker=10,
skip_speakers=False, skip_speakers=False,
verbose=False, verbose=False,
augmentation_config=None augmentation_config=None,
): ):
""" """
Args: Args:
@ -38,23 +39,25 @@ class SpeakerEncoderDataset(Dataset):
self.verbose = verbose self.verbose = verbose
self.__parse_items() self.__parse_items()
storage_max_size = storage_size * num_speakers_in_batch storage_max_size = storage_size * num_speakers_in_batch
self.storage = Storage(maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch) self.storage = Storage(
maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
)
self.sample_from_storage_p = float(sample_from_storage_p) self.sample_from_storage_p = float(sample_from_storage_p)
speakers_aux = list(self.speakers) speakers_aux = list(self.speakers)
speakers_aux.sort() speakers_aux.sort()
self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)} self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
# Augmentation # Augmentation
self.augmentator = None self.augmentator = None
self.gaussian_augmentation_config = None self.gaussian_augmentation_config = None
if augmentation_config: if augmentation_config:
self.data_augmentation_p = augmentation_config['p'] self.data_augmentation_p = augmentation_config["p"]
if self.data_augmentation_p and ('additive' in augmentation_config or 'rir' in augmentation_config): if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
self.augmentator = AugmentWAV(ap, augmentation_config) self.augmentator = AugmentWAV(ap, augmentation_config)
if 'gaussian' in augmentation_config.keys(): if "gaussian" in augmentation_config.keys():
self.gaussian_augmentation_config = augmentation_config['gaussian'] self.gaussian_augmentation_config = augmentation_config["gaussian"]
if self.verbose: if self.verbose:
print("\n > DataLoader initialization") print("\n > DataLoader initialization")
@ -231,9 +234,13 @@ class SpeakerEncoderDataset(Dataset):
offset = random.randint(0, wav.shape[0] - self.seq_len) offset = random.randint(0, wav.shape[0] - self.seq_len)
wav = wav[offset : offset + self.seq_len] wav = wav[offset : offset + self.seq_len]
# add random gaussian noise # add random gaussian noise
if self.gaussian_augmentation_config and self.gaussian_augmentation_config['p']: if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
if random.random() < self.gaussian_augmentation_config['p']: if random.random() < self.gaussian_augmentation_config["p"]:
wav += np.random.normal(self.gaussian_augmentation_config['min_amplitude'], self.gaussian_augmentation_config['max_amplitude'], size=len(wav)) wav += np.random.normal(
self.gaussian_augmentation_config["min_amplitude"],
self.gaussian_augmentation_config["max_amplitude"],
size=len(wav),
)
mel = self.ap.melspectrogram(wav) mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel)) feats_.append(torch.FloatTensor(mel))

View File

@ -162,6 +162,7 @@ class AngleProtoLoss(nn.Module):
L = self.criterion(cos_sim_matrix, label) L = self.criterion(cos_sim_matrix, label)
return L return L
class SoftmaxLoss(nn.Module): class SoftmaxLoss(nn.Module):
""" """
Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
@ -169,13 +170,14 @@ class SoftmaxLoss(nn.Module):
- embedding_dim (float): speaker embedding dim - embedding_dim (float): speaker embedding dim
- n_speakers (float): number of speakers - n_speakers (float): number of speakers
""" """
def __init__(self, embedding_dim, n_speakers): def __init__(self, embedding_dim, n_speakers):
super().__init__() super().__init__()
self.criterion = torch.nn.CrossEntropyLoss() self.criterion = torch.nn.CrossEntropyLoss()
self.fc = nn.Linear(embedding_dim, n_speakers) self.fc = nn.Linear(embedding_dim, n_speakers)
print('Initialised Softmax Loss') print("Initialised Softmax Loss")
def forward(self, x, label=None): def forward(self, x, label=None):
# reshape for compatibility # reshape for compatibility
@ -187,6 +189,7 @@ class SoftmaxLoss(nn.Module):
return L return L
class SoftmaxAngleProtoLoss(nn.Module): class SoftmaxAngleProtoLoss(nn.Module):
""" """
Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
@ -196,13 +199,14 @@ class SoftmaxAngleProtoLoss(nn.Module):
- init_w (float): defines the initial value of w - init_w (float): defines the initial value of w
- init_b (float): definies the initial value of b - init_b (float): definies the initial value of b
""" """
def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
super().__init__() super().__init__()
self.softmax = SoftmaxLoss(embedding_dim, n_speakers) self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
self.angleproto = AngleProtoLoss(init_w, init_b) self.angleproto = AngleProtoLoss(init_w, init_b)
print('Initialised SoftmaxAnglePrototypical Loss') print("Initialised SoftmaxAnglePrototypical Loss")
def forward(self, x, label=None): def forward(self, x, label=None):
""" """
@ -213,4 +217,4 @@ class SoftmaxAngleProtoLoss(nn.Module):
Ls = self.softmax(x, label) Ls = self.softmax(x, label)
return Ls+Lp return Ls + Lp

View File

@ -1,7 +1,8 @@
import torch
import numpy as np import numpy as np
import torch
import torch.nn as nn import torch.nn as nn
class SELayer(nn.Module): class SELayer(nn.Module):
def __init__(self, channel, reduction=8): def __init__(self, channel, reduction=8):
super(SELayer, self).__init__() super(SELayer, self).__init__()
@ -10,7 +11,7 @@ class SELayer(nn.Module):
nn.Linear(channel, channel // reduction), nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel), nn.Linear(channel // reduction, channel),
nn.Sigmoid() nn.Sigmoid(),
) )
def forward(self, x): def forward(self, x):
@ -19,6 +20,7 @@ class SELayer(nn.Module):
y = self.fc(y).view(b, c, 1, 1) y = self.fc(y).view(b, c, 1, 1)
return x * y return x * y
class SEBasicBlock(nn.Module): class SEBasicBlock(nn.Module):
expansion = 1 expansion = 1
@ -51,12 +53,22 @@ class SEBasicBlock(nn.Module):
out = self.relu(out) out = self.relu(out)
return out return out
class ResNetSpeakerEncoder(nn.Module): class ResNetSpeakerEncoder(nn.Module):
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153 """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
Adapted from: https://github.com/clovaai/voxceleb_trainer Adapted from: https://github.com/clovaai/voxceleb_trainer
""" """
# pylint: disable=W0102 # pylint: disable=W0102
def __init__(self, input_dim=64, proj_dim=512, layers=[3, 4, 6, 3], num_filters=[32, 64, 128, 256], encoder_type='ASP', log_input=False): def __init__(
self,
input_dim=64,
proj_dim=512,
layers=[3, 4, 6, 3],
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
):
super(ResNetSpeakerEncoder, self).__init__() super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type self.encoder_type = encoder_type
@ -74,7 +86,7 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim) self.instancenorm = nn.InstanceNorm1d(input_dim)
outmap_size = int(self.input_dim/8) outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential( self.attention = nn.Sequential(
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
@ -82,14 +94,14 @@ class ResNetSpeakerEncoder(nn.Module):
nn.BatchNorm1d(128), nn.BatchNorm1d(128),
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(dim=2), nn.Softmax(dim=2),
) )
if self.encoder_type == "SAP": if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP": elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2 out_dim = num_filters[3] * outmap_size * 2
else: else:
raise ValueError('Undefined encoder') raise ValueError("Undefined encoder")
self.fc = nn.Linear(out_dim, proj_dim) self.fc = nn.Linear(out_dim, proj_dim)
@ -98,7 +110,7 @@ class ResNetSpeakerEncoder(nn.Module):
def _init_layers(self): def _init_layers(self):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
@ -107,8 +119,7 @@ class ResNetSpeakerEncoder(nn.Module):
downsample = None downsample = None
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion), nn.BatchNorm2d(planes * block.expansion),
) )
@ -131,7 +142,7 @@ class ResNetSpeakerEncoder(nn.Module):
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
if self.log_input: if self.log_input:
x = (x+1e-6).log() x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1) x = self.instancenorm(x).unsqueeze(1)
x = self.conv1(x) x = self.conv1(x)
@ -151,7 +162,7 @@ class ResNetSpeakerEncoder(nn.Module):
x = torch.sum(x * w, dim=2) x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP": elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2) mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1) x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1) x = x.view(x.size()[0], -1)
@ -172,12 +183,12 @@ class ResNetSpeakerEncoder(nn.Module):
if max_len < num_frames: if max_len < num_frames:
num_frames = max_len num_frames = max_len
offsets = np.linspace(0, max_len-num_frames, num=num_eval) offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = [] frames_batch = []
for offset in offsets: for offset in offsets:
offset = int(offset) offset = int(offset)
end_offset = int(offset+num_frames) end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset] frames = x[:, offset:end_offset]
frames_batch.append(frames) frames_batch.append(frames)

View File

@ -25,10 +25,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
} }
) )
audio_augmentation : dict = field( audio_augmentation: dict = field(default_factory=lambda: {})
default_factory=lambda: {
}
)
storage: dict = field( storage: dict = field(
default_factory=lambda: { default_factory=lambda: {

View File

@ -1,18 +1,18 @@
import re import datetime
import glob
import os import os
import random
import re
from multiprocessing import Manager
import numpy as np import numpy as np
import torch import torch
import glob
import random
import datetime
from scipy import signal from scipy import signal
from multiprocessing import Manager
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
class Storage(object): class Storage(object):
def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8): def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
# use multiprocessing for threading safe # use multiprocessing for threading safe
@ -53,19 +53,19 @@ class Storage(object):
return self.storage[random.randint(0, storage_size)] return self.storage[random.randint(0, storage_size)]
def get_random_sample_fast(self): def get_random_sample_fast(self):
'''Call this method only when storage is full''' """Call this method only when storage is full"""
return self.storage[random.randint(0, self.safe_storage_size)] return self.storage[random.randint(0, self.safe_storage_size)]
class AugmentWAV(object):
class AugmentWAV(object):
def __init__(self, ap, augmentation_config): def __init__(self, ap, augmentation_config):
self.ap = ap self.ap = ap
self.use_additive_noise = False self.use_additive_noise = False
if 'additive' in augmentation_config.keys(): if "additive" in augmentation_config.keys():
self.additive_noise_config = augmentation_config['additive'] self.additive_noise_config = augmentation_config["additive"]
additive_path = self.additive_noise_config['sounds_path'] additive_path = self.additive_noise_config["sounds_path"]
if additive_path: if additive_path:
self.use_additive_noise = True self.use_additive_noise = True
# get noise types # get noise types
@ -74,12 +74,12 @@ class AugmentWAV(object):
if isinstance(self.additive_noise_config[key], dict): if isinstance(self.additive_noise_config[key], dict):
self.additive_noise_types.append(key) self.additive_noise_types.append(key)
additive_files = glob.glob(os.path.join(additive_path, '**/*.wav'), recursive=True) additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)
self.noise_list = {} self.noise_list = {}
for wav_file in additive_files: for wav_file in additive_files:
noise_dir = wav_file.replace(additive_path, '').split(os.sep)[0] noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
# ignore not listed directories # ignore not listed directories
if noise_dir not in self.additive_noise_types: if noise_dir not in self.additive_noise_types:
continue continue
@ -87,14 +87,16 @@ class AugmentWAV(object):
self.noise_list[noise_dir] = [] self.noise_list[noise_dir] = []
self.noise_list[noise_dir].append(wav_file) self.noise_list[noise_dir].append(wav_file)
print(f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}") print(
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
)
self.use_rir = False self.use_rir = False
if 'rir' in augmentation_config.keys(): if "rir" in augmentation_config.keys():
self.rir_config = augmentation_config['rir'] self.rir_config = augmentation_config["rir"]
if self.rir_config['rir_path']: if self.rir_config["rir_path"]:
self.rir_files = glob.glob(os.path.join(self.rir_config['rir_path'], '**/*.wav'), recursive=True) self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
self.use_rir = True self.use_rir = True
print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
@ -111,9 +113,15 @@ class AugmentWAV(object):
def additive_noise(self, noise_type, audio): def additive_noise(self, noise_type, audio):
clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
noise_list = random.sample(self.noise_list[noise_type], random.randint(self.additive_noise_config[noise_type]['min_num_noises'], self.additive_noise_config[noise_type]['max_num_noises'])) noise_list = random.sample(
self.noise_list[noise_type],
random.randint(
self.additive_noise_config[noise_type]["min_num_noises"],
self.additive_noise_config[noise_type]["max_num_noises"],
),
)
audio_len = audio.shape[0] audio_len = audio.shape[0]
noises_wav = None noises_wav = None
@ -123,7 +131,10 @@ class AugmentWAV(object):
if noiseaudio.shape[0] < audio_len: if noiseaudio.shape[0] < audio_len:
continue continue
noise_snr = random.uniform(self.additive_noise_config[noise_type]['min_snr_in_db'], self.additive_noise_config[noise_type]['max_num_noises']) noise_snr = random.uniform(
self.additive_noise_config[noise_type]["min_snr_in_db"],
self.additive_noise_config[noise_type]["max_num_noises"],
)
noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4) noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
@ -144,7 +155,7 @@ class AugmentWAV(object):
rir_file = random.choice(self.rir_files) rir_file = random.choice(self.rir_files)
rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
rir = rir / np.sqrt(np.sum(rir ** 2)) rir = rir / np.sqrt(np.sum(rir ** 2))
return signal.convolve(audio, rir, mode=self.rir_config['conv_mode'])[:audio_len] return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
def apply_one(self, audio): def apply_one(self, audio):
noise_type = random.choice(self.global_noise_list) noise_type = random.choice(self.global_noise_list)
@ -153,17 +164,25 @@ class AugmentWAV(object):
return self.additive_noise(noise_type, audio) return self.additive_noise(noise_type, audio)
def to_camel(text): def to_camel(text):
text = text.capitalize() text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(c): def setup_model(c):
if c.model_params['model_name'].lower() == 'lstm': if c.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(c.model_params["input_dim"], c.model_params["proj_dim"], c.model_params["lstm_dim"], c.model_params["num_lstm_layers"]) model = LSTMSpeakerEncoder(
elif c.model_params['model_name'].lower() == 'resnet': c.model_params["input_dim"],
c.model_params["proj_dim"],
c.model_params["lstm_dim"],
c.model_params["num_lstm_layers"],
)
elif c.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"]) model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
return model return model
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch): def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path) checkpoint_path = os.path.join(out_path, checkpoint_path)

View File

@ -6,6 +6,7 @@ from tests import get_tests_input_path
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
file_path = get_tests_input_path() file_path = get_tests_input_path()
@ -39,6 +40,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
assert output.shape[1] == 256 assert output.shape[1] == 256
assert len(output.shape) == 2 assert len(output.shape) == 2
class ResNetSpeakerEncoderTests(unittest.TestCase): class ResNetSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201 # pylint: disable=R0201
def test_in_out(self): def test_in_out(self):
@ -65,6 +67,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
assert output.shape[1] == 256 assert output.shape[1] == 256
assert len(output.shape) == 2 assert len(output.shape) == 2
class GE2ELossTests(unittest.TestCase): class GE2ELossTests(unittest.TestCase):
# pylint: disable=R0201 # pylint: disable=R0201
def test_in_out(self): def test_in_out(self):
@ -92,6 +95,7 @@ class GE2ELossTests(unittest.TestCase):
output = loss.forward(dummy_input) output = loss.forward(dummy_input)
assert output.item() < 0.005 assert output.item() < 0.005
class AngleProtoLossTests(unittest.TestCase): class AngleProtoLossTests(unittest.TestCase):
# pylint: disable=R0201 # pylint: disable=R0201
def test_in_out(self): def test_in_out(self):
@ -121,6 +125,7 @@ class AngleProtoLossTests(unittest.TestCase):
output = loss.forward(dummy_input) output = loss.forward(dummy_input)
assert output.item() < 0.005 assert output.item() < 0.005
class SoftmaxAngleProtoLossTests(unittest.TestCase): class SoftmaxAngleProtoLossTests(unittest.TestCase):
# pylint: disable=R0201 # pylint: disable=R0201
def test_in_out(self): def test_in_out(self):

View File

@ -46,7 +46,7 @@ run_cli(command_train)
shutil.rmtree(continue_path) shutil.rmtree(continue_path)
# test resnet speaker encoder # test resnet speaker encoder
config.model_params['model_name'] = "resnet" config.model_params["model_name"] = "resnet"
config.save_json(config_path) config.save_json(config_path)
# train the model for one epoch # train the model for one epoch