mirror of https://github.com/coqui-ai/TTS.git
make style
This commit is contained in:
parent 975531b3f2
commit bec85ac58d
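The diff below is a pure formatting pass: strings become double-quoted, long calls are wrapped with trailing commas, operators get surrounding spaces, and imports are re-sorted. That pattern suggests the `make style` target runs an auto-formatter plus an import sorter, most likely black and isort; a rough local reproduction, with tool names, line length, and target directories as assumptions rather than facts taken from this page:

import subprocess

# Assumed command lines; the Makefile behind "make style" is not shown here.
for cmd in (
    ["black", "--line-length", "120", "TTS", "tests"],
    ["isort", "TTS", "tests"],
):
    subprocess.run(cmd, check=True)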
@@ -5,11 +5,11 @@ import os
 import torch
 from tqdm import tqdm
 
+from TTS.config import BaseDatasetConfig, load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
-from TTS.config import load_config, BaseDatasetConfig
 
 parser = argparse.ArgumentParser(
     description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
@@ -100,7 +100,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
 
 if speaker_mapping:
     # save speaker_mapping if target dataset is defined
-    if '.json' not in args.output_path:
+    if ".json" not in args.output_path:
         mapping_file_path = os.path.join(args.output_path, "speakers.json")
     else:
         mapping_file_path = args.output_path
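The only change in this hunk is quote style. For context, the branch lets --output_path name either a directory (a default speakers.json is created inside it) or an explicit .json file; a condensed sketch of that decision, with hypothetical example paths:

import os

def resolve_mapping_path(output_path):
    # Directory-style argument -> default file name; explicit .json path -> used verbatim.
    if ".json" not in output_path:
        return os.path.join(output_path, "speakers.json")
    return output_path

print(resolve_mapping_path("embeddings"))           # embeddings/speakers.json
print(resolve_mapping_path("embeddings/spk.json"))  # embeddings/spk.json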
@@ -10,10 +10,8 @@ import torch
 from torch.utils.data import DataLoader
 
 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
-
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
-
 from TTS.speaker_encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.utils.arguments import init_training
@@ -45,7 +43,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         storage_size=c.storage["storage_size"],
         sample_from_storage_p=c.storage["sample_from_storage_p"],
         verbose=verbose,
-        augmentation_config=c.audio_augmentation
+        augmentation_config=c.audio_augmentation,
     )
 
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
@@ -170,19 +168,18 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         raise Exception("The %s not is a loss supported" % c.loss)
 
-
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
             model.load_state_dict(checkpoint["model"])
 
-            if 'criterion' in checkpoint:
+            if "criterion" in checkpoint:
                 criterion.load_state_dict(checkpoint["criterion"])
 
         except (KeyError, RuntimeError):
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
+            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
             model.load_state_dict(model_dict)
             del model_dict
         for group in optimizer.param_groups:
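Aside from quotes, the only change above is a dropped blank line. The restore pattern itself, try a full load_state_dict and fall back to partial initialization, can be sketched standalone; set_init_dict is the repo's helper, and the fallback below inlines the usual filter-by-matching-name-and-shape logic as an assumption about what it does:

import torch

def restore_weights(model, restore_path):
    checkpoint = torch.load(restore_path)
    try:
        model.load_state_dict(checkpoint["model"])
    except (KeyError, RuntimeError):
        # Partial init: keep only checkpoint tensors whose name and shape
        # match the current model (assumed equivalent to set_init_dict).
        model_dict = model.state_dict()
        pretrained = {
            k: v
            for k, v in checkpoint["model"].items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(pretrained)
        model.load_state_dict(model_dict)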
@@ -1,24 +1,25 @@
 import random
 
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 
 from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
 
+
 class SpeakerEncoderDataset(Dataset):
     def __init__(
         self,
         ap,
         meta_data,
         voice_len=1.6,
         num_speakers_in_batch=64,
         storage_size=1,
         sample_from_storage_p=0.5,
         num_utter_per_speaker=10,
         skip_speakers=False,
         verbose=False,
-        augmentation_config=None
+        augmentation_config=None,
     ):
         """
         Args:
@@ -38,23 +39,25 @@ class SpeakerEncoderDataset(Dataset):
         self.verbose = verbose
         self.__parse_items()
         storage_max_size = storage_size * num_speakers_in_batch
-        self.storage = Storage(maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch)
+        self.storage = Storage(
+            maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
+        )
         self.sample_from_storage_p = float(sample_from_storage_p)
 
         speakers_aux = list(self.speakers)
         speakers_aux.sort()
-        self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)}
+        self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
 
         # Augmentation
         self.augmentator = None
         self.gaussian_augmentation_config = None
         if augmentation_config:
-            self.data_augmentation_p = augmentation_config['p']
-            if self.data_augmentation_p and ('additive' in augmentation_config or 'rir' in augmentation_config):
+            self.data_augmentation_p = augmentation_config["p"]
+            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
                 self.augmentator = AugmentWAV(ap, augmentation_config)
 
-            if 'gaussian' in augmentation_config.keys():
-                self.gaussian_augmentation_config = augmentation_config['gaussian']
+            if "gaussian" in augmentation_config.keys():
+                self.gaussian_augmentation_config = augmentation_config["gaussian"]
 
         if self.verbose:
             print("\n > DataLoader initialization")
@@ -231,9 +234,13 @@ class SpeakerEncoderDataset(Dataset):
             offset = random.randint(0, wav.shape[0] - self.seq_len)
             wav = wav[offset : offset + self.seq_len]
             # add random gaussian noise
-            if self.gaussian_augmentation_config and self.gaussian_augmentation_config['p']:
-                if random.random() < self.gaussian_augmentation_config['p']:
-                    wav += np.random.normal(self.gaussian_augmentation_config['min_amplitude'], self.gaussian_augmentation_config['max_amplitude'], size=len(wav))
+            if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
+                if random.random() < self.gaussian_augmentation_config["p"]:
+                    wav += np.random.normal(
+                        self.gaussian_augmentation_config["min_amplitude"],
+                        self.gaussian_augmentation_config["max_amplitude"],
+                        size=len(wav),
+                    )
             mel = self.ap.melspectrogram(wav)
             feats_.append(torch.FloatTensor(mel))
 
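The Gaussian branch above is only rewrapped; its logic is unchanged. As a standalone sketch of what it does; note that on both sides of the diff, min_amplitude and max_amplitude are passed to np.random.normal as mean and standard deviation, so the key names may not describe what the call actually draws:

import random

import numpy as np

def maybe_add_gaussian_noise(wav, cfg):
    # Mirrors the branch above: with probability cfg["p"], add noise drawn
    # from np.random.normal(cfg["min_amplitude"], cfg["max_amplitude"], ...).
    if cfg and cfg["p"] and random.random() < cfg["p"]:
        wav = wav + np.random.normal(cfg["min_amplitude"], cfg["max_amplitude"], size=len(wav))
    return wav

wav = maybe_add_gaussian_noise(np.zeros(16000), {"p": 1.0, "min_amplitude": 0.0, "max_amplitude": 1e-5})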
@@ -162,6 +162,7 @@ class AngleProtoLoss(nn.Module):
         L = self.criterion(cos_sim_matrix, label)
         return L
 
+
 class SoftmaxLoss(nn.Module):
     """
     Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
@@ -169,13 +170,14 @@ class SoftmaxLoss(nn.Module):
     - embedding_dim (float): speaker embedding dim
     - n_speakers (float): number of speakers
     """
 
     def __init__(self, embedding_dim, n_speakers):
         super().__init__()
 
         self.criterion = torch.nn.CrossEntropyLoss()
         self.fc = nn.Linear(embedding_dim, n_speakers)
 
-        print('Initialised Softmax Loss')
+        print("Initialised Softmax Loss")
 
     def forward(self, x, label=None):
         # reshape for compatibility
@@ -187,6 +189,7 @@ class SoftmaxLoss(nn.Module):
 
         return L
 
+
 class SoftmaxAngleProtoLoss(nn.Module):
     """
     Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
@@ -196,13 +199,14 @@ class SoftmaxAngleProtoLoss(nn.Module):
     - init_w (float): defines the initial value of w
     - init_b (float): definies the initial value of b
     """
 
     def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
         super().__init__()
 
         self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
         self.angleproto = AngleProtoLoss(init_w, init_b)
 
-        print('Initialised SoftmaxAnglePrototypical Loss')
+        print("Initialised SoftmaxAnglePrototypical Loss")
 
     def forward(self, x, label=None):
         """
@@ -213,4 +217,4 @@ class SoftmaxAngleProtoLoss(nn.Module):
 
         Ls = self.softmax(x, label)
 
-        return Ls+Lp
+        return Ls + Lp
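Everything in the losses file is quote style, spacing, and blank lines; behavior is identical. A minimal usage sketch of the combined loss, with tensor shapes assumed from the constructor arguments and the test files later in this diff (speakers x utterances x embedding_dim):

import torch

from TTS.speaker_encoder.losses import SoftmaxAngleProtoLoss

embeddings = torch.randn(4, 5, 256)    # assumed: 4 speakers, 5 utterances each
labels = torch.randint(0, 10, (4, 5))  # assumed: speaker class ids per utterance
criterion = SoftmaxAngleProtoLoss(embedding_dim=256, n_speakers=10)
loss = criterion(embeddings, labels)   # Ls + Lp, per the last hunk above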
@@ -1,7 +1,8 @@
-import torch
 import numpy as np
+import torch
 import torch.nn as nn
 
 
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -10,7 +11,7 @@ class SELayer(nn.Module):
             nn.Linear(channel, channel // reduction),
             nn.ReLU(inplace=True),
             nn.Linear(channel // reduction, channel),
-            nn.Sigmoid()
+            nn.Sigmoid(),
         )
 
     def forward(self, x):
@@ -19,6 +20,7 @@ class SELayer(nn.Module):
         y = self.fc(y).view(b, c, 1, 1)
         return x * y
 
+
 class SEBasicBlock(nn.Module):
     expansion = 1
 
|
||||||
out = self.relu(out)
|
out = self.relu(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ResNetSpeakerEncoder(nn.Module):
|
class ResNetSpeakerEncoder(nn.Module):
|
||||||
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
|
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
|
||||||
Adapted from: https://github.com/clovaai/voxceleb_trainer
|
Adapted from: https://github.com/clovaai/voxceleb_trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=W0102
|
# pylint: disable=W0102
|
||||||
def __init__(self, input_dim=64, proj_dim=512, layers=[3, 4, 6, 3], num_filters=[32, 64, 128, 256], encoder_type='ASP', log_input=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim=64,
|
||||||
|
proj_dim=512,
|
||||||
|
layers=[3, 4, 6, 3],
|
||||||
|
num_filters=[32, 64, 128, 256],
|
||||||
|
encoder_type="ASP",
|
||||||
|
log_input=False,
|
||||||
|
):
|
||||||
super(ResNetSpeakerEncoder, self).__init__()
|
super(ResNetSpeakerEncoder, self).__init__()
|
||||||
|
|
||||||
self.encoder_type = encoder_type
|
self.encoder_type = encoder_type
|
||||||
|
@ -74,7 +86,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
|
|
||||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||||
|
|
||||||
outmap_size = int(self.input_dim/8)
|
outmap_size = int(self.input_dim / 8)
|
||||||
|
|
||||||
self.attention = nn.Sequential(
|
self.attention = nn.Sequential(
|
||||||
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
||||||
|
@@ -82,14 +94,14 @@ class ResNetSpeakerEncoder(nn.Module):
             nn.BatchNorm1d(128),
             nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
             nn.Softmax(dim=2),
         )
 
         if self.encoder_type == "SAP":
             out_dim = num_filters[3] * outmap_size
         elif self.encoder_type == "ASP":
             out_dim = num_filters[3] * outmap_size * 2
         else:
-            raise ValueError('Undefined encoder')
+            raise ValueError("Undefined encoder")
 
         self.fc = nn.Linear(out_dim, proj_dim)
 
@@ -98,7 +110,7 @@ class ResNetSpeakerEncoder(nn.Module):
     def _init_layers(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -107,8 +119,7 @@ class ResNetSpeakerEncoder(nn.Module):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
-                nn.Conv2d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
+                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                 nn.BatchNorm2d(planes * block.expansion),
             )
 
@@ -131,7 +142,7 @@ class ResNetSpeakerEncoder(nn.Module):
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 if self.log_input:
-                    x = (x+1e-6).log()
+                    x = (x + 1e-6).log()
                 x = self.instancenorm(x).unsqueeze(1)
 
         x = self.conv1(x)
|
||||||
x = torch.sum(x * w, dim=2)
|
x = torch.sum(x * w, dim=2)
|
||||||
elif self.encoder_type == "ASP":
|
elif self.encoder_type == "ASP":
|
||||||
mu = torch.sum(x * w, dim=2)
|
mu = torch.sum(x * w, dim=2)
|
||||||
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
|
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
|
||||||
x = torch.cat((mu, sg), 1)
|
x = torch.cat((mu, sg), 1)
|
||||||
|
|
||||||
x = x.view(x.size()[0], -1)
|
x = x.view(x.size()[0], -1)
|
||||||
|
@ -172,12 +183,12 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
if max_len < num_frames:
|
if max_len < num_frames:
|
||||||
num_frames = max_len
|
num_frames = max_len
|
||||||
|
|
||||||
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
|
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
|
||||||
|
|
||||||
frames_batch = []
|
frames_batch = []
|
||||||
for offset in offsets:
|
for offset in offsets:
|
||||||
offset = int(offset)
|
offset = int(offset)
|
||||||
end_offset = int(offset+num_frames)
|
end_offset = int(offset + num_frames)
|
||||||
frames = x[:, offset:end_offset]
|
frames = x[:, offset:end_offset]
|
||||||
frames_batch.append(frames)
|
frames_batch.append(frames)
|
||||||
|
|
||||||
|
|
|
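The changes above only add spaces around operators. For orientation, the offsets computation is the usual multi-window evaluation trick: num_eval evenly spaced slice starts over the utterance. A tiny illustration with made-up values:

import numpy as np

max_len, num_frames, num_eval = 400, 160, 5
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
print(offsets)  # [  0.  60. 120. 180. 240.] -> evenly spaced window starts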
@@ -25,10 +25,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
         }
     )
 
-    audio_augmentation : dict = field(
-        default_factory=lambda: {
-        }
-    )
+    audio_augmentation: dict = field(default_factory=lambda: {})
 
     storage: dict = field(
         default_factory=lambda: {
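The config hunk collapses an empty default dict onto one line. The default_factory indirection is the standard dataclass idiom for mutable defaults, worth keeping in mind when editing these config classes; a minimal illustration with a hypothetical class name:

from dataclasses import dataclass, field

@dataclass
class ExampleConfig:
    # A mutable default must go through default_factory; a bare `= {}`
    # would be rejected by dataclasses (shared mutable default).
    audio_augmentation: dict = field(default_factory=lambda: {})

print(ExampleConfig().audio_augmentation)  # {}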
@@ -1,18 +1,18 @@
-import re
+import datetime
+import glob
 import os
+import random
+import re
+from multiprocessing import Manager
 
 import numpy as np
 import torch
-import glob
-import random
-import datetime
 
 from scipy import signal
-from multiprocessing import Manager
 
 from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
 
 
 class Storage(object):
     def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
         # use multiprocessing for threading safe
@@ -53,19 +53,19 @@ class Storage(object):
         return self.storage[random.randint(0, storage_size)]
 
     def get_random_sample_fast(self):
-        '''Call this method only when storage is full'''
+        """Call this method only when storage is full"""
         return self.storage[random.randint(0, self.safe_storage_size)]
 
-class AugmentWAV(object):
-
+
+class AugmentWAV(object):
     def __init__(self, ap, augmentation_config):
 
         self.ap = ap
         self.use_additive_noise = False
 
-        if 'additive' in augmentation_config.keys():
-            self.additive_noise_config = augmentation_config['additive']
-            additive_path = self.additive_noise_config['sounds_path']
+        if "additive" in augmentation_config.keys():
+            self.additive_noise_config = augmentation_config["additive"]
+            additive_path = self.additive_noise_config["sounds_path"]
             if additive_path:
                 self.use_additive_noise = True
                 # get noise types
@@ -74,12 +74,12 @@ class AugmentWAV(object):
                 if isinstance(self.additive_noise_config[key], dict):
                     self.additive_noise_types.append(key)
 
-            additive_files = glob.glob(os.path.join(additive_path, '**/*.wav'), recursive=True)
+            additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)
 
             self.noise_list = {}
 
             for wav_file in additive_files:
-                noise_dir = wav_file.replace(additive_path, '').split(os.sep)[0]
+                noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
                 # ignore not listed directories
                 if noise_dir not in self.additive_noise_types:
                     continue
@@ -87,14 +87,16 @@ class AugmentWAV(object):
                     self.noise_list[noise_dir] = []
                 self.noise_list[noise_dir].append(wav_file)
 
-            print(f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}")
+            print(
+                f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
+            )
 
         self.use_rir = False
 
-        if 'rir' in augmentation_config.keys():
-            self.rir_config = augmentation_config['rir']
-            if self.rir_config['rir_path']:
-                self.rir_files = glob.glob(os.path.join(self.rir_config['rir_path'], '**/*.wav'), recursive=True)
+        if "rir" in augmentation_config.keys():
+            self.rir_config = augmentation_config["rir"]
+            if self.rir_config["rir_path"]:
+                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                 self.use_rir = True
 
                 print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
@@ -111,9 +113,15 @@ class AugmentWAV(object):
 
     def additive_noise(self, noise_type, audio):
 
-        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
+        clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
 
-        noise_list = random.sample(self.noise_list[noise_type], random.randint(self.additive_noise_config[noise_type]['min_num_noises'], self.additive_noise_config[noise_type]['max_num_noises']))
+        noise_list = random.sample(
+            self.noise_list[noise_type],
+            random.randint(
+                self.additive_noise_config[noise_type]["min_num_noises"],
+                self.additive_noise_config[noise_type]["max_num_noises"],
+            ),
+        )
 
         audio_len = audio.shape[0]
         noises_wav = None
@@ -123,7 +131,10 @@ class AugmentWAV(object):
             if noiseaudio.shape[0] < audio_len:
                 continue
 
-            noise_snr = random.uniform(self.additive_noise_config[noise_type]['min_snr_in_db'], self.additive_noise_config[noise_type]['max_num_noises'])
+            noise_snr = random.uniform(
+                self.additive_noise_config[noise_type]["min_snr_in_db"],
+                self.additive_noise_config[noise_type]["max_num_noises"],
+            )
             noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
             noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
 
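Worth flagging: on both sides of the hunk above, the upper bound of the SNR draw is the max_num_noises key rather than a maximum SNR, so the style pass faithfully preserves what looks like a pre-existing bug. A corrected sketch, assuming the augmentation config defined a max_snr_in_db key (hypothetical, not present in this diff):

import random

def draw_noise_snr(noise_config):
    # Hypothetical fix: draw from [min_snr_in_db, max_snr_in_db] instead of
    # mixing the SNR floor with the max_num_noises count.
    return random.uniform(noise_config["min_snr_in_db"], noise_config["max_snr_in_db"])

print(draw_noise_snr({"min_snr_in_db": 0.0, "max_snr_in_db": 15.0}))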
@@ -144,7 +155,7 @@ class AugmentWAV(object):
         rir_file = random.choice(self.rir_files)
         rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
         rir = rir / np.sqrt(np.sum(rir ** 2))
-        return signal.convolve(audio, rir, mode=self.rir_config['conv_mode'])[:audio_len]
+        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
 
     def apply_one(self, audio):
         noise_type = random.choice(self.global_noise_list)
@@ -153,17 +164,25 @@ class AugmentWAV(object):
 
         return self.additive_noise(noise_type, audio)
 
 
 def to_camel(text):
     text = text.capitalize()
     return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
 
 
 def setup_model(c):
-    if c.model_params['model_name'].lower() == 'lstm':
-        model = LSTMSpeakerEncoder(c.model_params["input_dim"], c.model_params["proj_dim"], c.model_params["lstm_dim"], c.model_params["num_lstm_layers"])
-    elif c.model_params['model_name'].lower() == 'resnet':
+    if c.model_params["model_name"].lower() == "lstm":
+        model = LSTMSpeakerEncoder(
+            c.model_params["input_dim"],
+            c.model_params["proj_dim"],
+            c.model_params["lstm_dim"],
+            c.model_params["num_lstm_layers"],
+        )
+    elif c.model_params["model_name"].lower() == "resnet":
         model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
     return model
 
 
 def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
     checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
     checkpoint_path = os.path.join(out_path, checkpoint_path)
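For orientation, the to_camel helper (only touched by the re-sorted imports here) turns snake_case into CamelCase; copied verbatim from the hunk above with a quick check:

import re

def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

print(to_camel("speaker_encoder"))  # SpeakerEncoder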
@@ -6,6 +6,7 @@ from tests import get_tests_input_path
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
 
 file_path = get_tests_input_path()
 
+
@@ -39,6 +40,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
         assert output.shape[1] == 256
         assert len(output.shape) == 2
 
+
 class ResNetSpeakerEncoderTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -65,6 +67,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
         assert output.shape[1] == 256
         assert len(output.shape) == 2
 
+
 class GE2ELossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -92,6 +95,7 @@ class GE2ELossTests(unittest.TestCase):
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
 
+
 class AngleProtoLossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -121,6 +125,7 @@ class AngleProtoLossTests(unittest.TestCase):
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
 
+
 class SoftmaxAngleProtoLossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -46,7 +46,7 @@ run_cli(command_train)
 shutil.rmtree(continue_path)
 
 # test resnet speaker encoder
-config.model_params['model_name'] = "resnet"
+config.model_params["model_name"] = "resnet"
 config.save_json(config_path)
 
 # train the model for one epoch