mirror of https://github.com/coqui-ai/TTS.git
make style
This commit is contained in:
parent
975531b3f2
commit
bec85ac58d
|
@ -5,11 +5,11 @@ import os
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import BaseDatasetConfig, load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.config import load_config, BaseDatasetConfig
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
|
||||
|
@ -100,7 +100,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
|
|||
|
||||
if speaker_mapping:
|
||||
# save speaker_mapping if target dataset is defined
|
||||
if '.json' not in args.output_path:
|
||||
if ".json" not in args.output_path:
|
||||
mapping_file_path = os.path.join(args.output_path, "speakers.json")
|
||||
else:
|
||||
mapping_file_path = args.output_path
|
||||
|
|
|
@ -10,10 +10,8 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
|
||||
|
||||
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.utils.arguments import init_training
|
||||
|
@ -45,7 +43,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
|
|||
storage_size=c.storage["storage_size"],
|
||||
sample_from_storage_p=c.storage["sample_from_storage_p"],
|
||||
verbose=verbose,
|
||||
augmentation_config=c.audio_augmentation
|
||||
augmentation_config=c.audio_augmentation,
|
||||
)
|
||||
|
||||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
@ -170,19 +168,18 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
else:
|
||||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path)
|
||||
try:
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
||||
if 'criterion' in checkpoint:
|
||||
if "criterion" in checkpoint:
|
||||
criterion.load_state_dict(checkpoint["criterion"])
|
||||
|
||||
except (KeyError, RuntimeError):
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
for group in optimizer.param_groups:
|
||||
|
|
|
@ -1,24 +1,25 @@
|
|||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
|
||||
|
||||
|
||||
class SpeakerEncoderDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
ap,
|
||||
meta_data,
|
||||
voice_len=1.6,
|
||||
num_speakers_in_batch=64,
|
||||
storage_size=1,
|
||||
sample_from_storage_p=0.5,
|
||||
num_utter_per_speaker=10,
|
||||
skip_speakers=False,
|
||||
verbose=False,
|
||||
augmentation_config=None
|
||||
self,
|
||||
ap,
|
||||
meta_data,
|
||||
voice_len=1.6,
|
||||
num_speakers_in_batch=64,
|
||||
storage_size=1,
|
||||
sample_from_storage_p=0.5,
|
||||
num_utter_per_speaker=10,
|
||||
skip_speakers=False,
|
||||
verbose=False,
|
||||
augmentation_config=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -38,23 +39,25 @@ class SpeakerEncoderDataset(Dataset):
|
|||
self.verbose = verbose
|
||||
self.__parse_items()
|
||||
storage_max_size = storage_size * num_speakers_in_batch
|
||||
self.storage = Storage(maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch)
|
||||
self.storage = Storage(
|
||||
maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
|
||||
)
|
||||
self.sample_from_storage_p = float(sample_from_storage_p)
|
||||
|
||||
speakers_aux = list(self.speakers)
|
||||
speakers_aux.sort()
|
||||
self.speakerid_to_classid = {key : i for i, key in enumerate(speakers_aux)}
|
||||
self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
|
||||
|
||||
# Augmentation
|
||||
self.augmentator = None
|
||||
self.gaussian_augmentation_config = None
|
||||
if augmentation_config:
|
||||
self.data_augmentation_p = augmentation_config['p']
|
||||
if self.data_augmentation_p and ('additive' in augmentation_config or 'rir' in augmentation_config):
|
||||
self.data_augmentation_p = augmentation_config["p"]
|
||||
if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
|
||||
self.augmentator = AugmentWAV(ap, augmentation_config)
|
||||
|
||||
if 'gaussian' in augmentation_config.keys():
|
||||
self.gaussian_augmentation_config = augmentation_config['gaussian']
|
||||
if "gaussian" in augmentation_config.keys():
|
||||
self.gaussian_augmentation_config = augmentation_config["gaussian"]
|
||||
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
|
@ -231,9 +234,13 @@ class SpeakerEncoderDataset(Dataset):
|
|||
offset = random.randint(0, wav.shape[0] - self.seq_len)
|
||||
wav = wav[offset : offset + self.seq_len]
|
||||
# add random gaussian noise
|
||||
if self.gaussian_augmentation_config and self.gaussian_augmentation_config['p']:
|
||||
if random.random() < self.gaussian_augmentation_config['p']:
|
||||
wav += np.random.normal(self.gaussian_augmentation_config['min_amplitude'], self.gaussian_augmentation_config['max_amplitude'], size=len(wav))
|
||||
if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
|
||||
if random.random() < self.gaussian_augmentation_config["p"]:
|
||||
wav += np.random.normal(
|
||||
self.gaussian_augmentation_config["min_amplitude"],
|
||||
self.gaussian_augmentation_config["max_amplitude"],
|
||||
size=len(wav),
|
||||
)
|
||||
mel = self.ap.melspectrogram(wav)
|
||||
feats_.append(torch.FloatTensor(mel))
|
||||
|
||||
|
|
|
@ -162,6 +162,7 @@ class AngleProtoLoss(nn.Module):
|
|||
L = self.criterion(cos_sim_matrix, label)
|
||||
return L
|
||||
|
||||
|
||||
class SoftmaxLoss(nn.Module):
|
||||
"""
|
||||
Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
|
||||
|
@ -169,13 +170,14 @@ class SoftmaxLoss(nn.Module):
|
|||
- embedding_dim (float): speaker embedding dim
|
||||
- n_speakers (float): number of speakers
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, n_speakers):
|
||||
super().__init__()
|
||||
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
self.fc = nn.Linear(embedding_dim, n_speakers)
|
||||
|
||||
print('Initialised Softmax Loss')
|
||||
print("Initialised Softmax Loss")
|
||||
|
||||
def forward(self, x, label=None):
|
||||
# reshape for compatibility
|
||||
|
@ -187,6 +189,7 @@ class SoftmaxLoss(nn.Module):
|
|||
|
||||
return L
|
||||
|
||||
|
||||
class SoftmaxAngleProtoLoss(nn.Module):
|
||||
"""
|
||||
Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
|
||||
|
@ -196,13 +199,14 @@ class SoftmaxAngleProtoLoss(nn.Module):
|
|||
- init_w (float): defines the initial value of w
|
||||
- init_b (float): definies the initial value of b
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
|
||||
super().__init__()
|
||||
|
||||
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
|
||||
self.angleproto = AngleProtoLoss(init_w, init_b)
|
||||
|
||||
print('Initialised SoftmaxAnglePrototypical Loss')
|
||||
print("Initialised SoftmaxAnglePrototypical Loss")
|
||||
|
||||
def forward(self, x, label=None):
|
||||
"""
|
||||
|
@ -213,4 +217,4 @@ class SoftmaxAngleProtoLoss(nn.Module):
|
|||
|
||||
Ls = self.softmax(x, label)
|
||||
|
||||
return Ls+Lp
|
||||
return Ls + Lp
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super(SELayer, self).__init__()
|
||||
|
@ -10,7 +11,7 @@ class SELayer(nn.Module):
|
|||
nn.Linear(channel, channel // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid()
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -19,6 +20,7 @@ class SELayer(nn.Module):
|
|||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y
|
||||
|
||||
|
||||
class SEBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
|
@ -51,12 +53,22 @@ class SEBasicBlock(nn.Module):
|
|||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNetSpeakerEncoder(nn.Module):
|
||||
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
|
||||
Adapted from: https://github.com/clovaai/voxceleb_trainer
|
||||
"""
|
||||
|
||||
# pylint: disable=W0102
|
||||
def __init__(self, input_dim=64, proj_dim=512, layers=[3, 4, 6, 3], num_filters=[32, 64, 128, 256], encoder_type='ASP', log_input=False):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim=64,
|
||||
proj_dim=512,
|
||||
layers=[3, 4, 6, 3],
|
||||
num_filters=[32, 64, 128, 256],
|
||||
encoder_type="ASP",
|
||||
log_input=False,
|
||||
):
|
||||
super(ResNetSpeakerEncoder, self).__init__()
|
||||
|
||||
self.encoder_type = encoder_type
|
||||
|
@ -74,7 +86,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
outmap_size = int(self.input_dim/8)
|
||||
outmap_size = int(self.input_dim / 8)
|
||||
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
||||
|
@ -82,14 +94,14 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
nn.BatchNorm1d(128),
|
||||
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
|
||||
nn.Softmax(dim=2),
|
||||
)
|
||||
)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
out_dim = num_filters[3] * outmap_size
|
||||
elif self.encoder_type == "ASP":
|
||||
out_dim = num_filters[3] * outmap_size * 2
|
||||
else:
|
||||
raise ValueError('Undefined encoder')
|
||||
raise ValueError("Undefined encoder")
|
||||
|
||||
self.fc = nn.Linear(out_dim, proj_dim)
|
||||
|
||||
|
@ -98,7 +110,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
def _init_layers(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
@ -107,8 +119,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
|
@ -131,7 +142,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
if self.log_input:
|
||||
x = (x+1e-6).log()
|
||||
x = (x + 1e-6).log()
|
||||
x = self.instancenorm(x).unsqueeze(1)
|
||||
|
||||
x = self.conv1(x)
|
||||
|
@ -151,7 +162,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
x = torch.sum(x * w, dim=2)
|
||||
elif self.encoder_type == "ASP":
|
||||
mu = torch.sum(x * w, dim=2)
|
||||
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
|
||||
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
|
||||
x = torch.cat((mu, sg), 1)
|
||||
|
||||
x = x.view(x.size()[0], -1)
|
||||
|
@ -172,12 +183,12 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
if max_len < num_frames:
|
||||
num_frames = max_len
|
||||
|
||||
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
|
||||
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
|
||||
|
||||
frames_batch = []
|
||||
for offset in offsets:
|
||||
offset = int(offset)
|
||||
end_offset = int(offset+num_frames)
|
||||
end_offset = int(offset + num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
frames_batch.append(frames)
|
||||
|
||||
|
|
|
@ -25,10 +25,7 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
|
|||
}
|
||||
)
|
||||
|
||||
audio_augmentation : dict = field(
|
||||
default_factory=lambda: {
|
||||
}
|
||||
)
|
||||
audio_augmentation: dict = field(default_factory=lambda: {})
|
||||
|
||||
storage: dict = field(
|
||||
default_factory=lambda: {
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
import re
|
||||
import datetime
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from multiprocessing import Manager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import glob
|
||||
import random
|
||||
import datetime
|
||||
|
||||
from scipy import signal
|
||||
from multiprocessing import Manager
|
||||
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
|
||||
|
||||
class Storage(object):
|
||||
def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
|
||||
# use multiprocessing for threading safe
|
||||
|
@ -53,19 +53,19 @@ class Storage(object):
|
|||
return self.storage[random.randint(0, storage_size)]
|
||||
|
||||
def get_random_sample_fast(self):
|
||||
'''Call this method only when storage is full'''
|
||||
"""Call this method only when storage is full"""
|
||||
return self.storage[random.randint(0, self.safe_storage_size)]
|
||||
|
||||
class AugmentWAV(object):
|
||||
|
||||
class AugmentWAV(object):
|
||||
def __init__(self, ap, augmentation_config):
|
||||
|
||||
self.ap = ap
|
||||
self.use_additive_noise = False
|
||||
|
||||
if 'additive' in augmentation_config.keys():
|
||||
self.additive_noise_config = augmentation_config['additive']
|
||||
additive_path = self.additive_noise_config['sounds_path']
|
||||
if "additive" in augmentation_config.keys():
|
||||
self.additive_noise_config = augmentation_config["additive"]
|
||||
additive_path = self.additive_noise_config["sounds_path"]
|
||||
if additive_path:
|
||||
self.use_additive_noise = True
|
||||
# get noise types
|
||||
|
@ -74,12 +74,12 @@ class AugmentWAV(object):
|
|||
if isinstance(self.additive_noise_config[key], dict):
|
||||
self.additive_noise_types.append(key)
|
||||
|
||||
additive_files = glob.glob(os.path.join(additive_path, '**/*.wav'), recursive=True)
|
||||
additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)
|
||||
|
||||
self.noise_list = {}
|
||||
|
||||
for wav_file in additive_files:
|
||||
noise_dir = wav_file.replace(additive_path, '').split(os.sep)[0]
|
||||
noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
|
||||
# ignore not listed directories
|
||||
if noise_dir not in self.additive_noise_types:
|
||||
continue
|
||||
|
@ -87,14 +87,16 @@ class AugmentWAV(object):
|
|||
self.noise_list[noise_dir] = []
|
||||
self.noise_list[noise_dir].append(wav_file)
|
||||
|
||||
print(f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}")
|
||||
print(
|
||||
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
|
||||
)
|
||||
|
||||
self.use_rir = False
|
||||
|
||||
if 'rir' in augmentation_config.keys():
|
||||
self.rir_config = augmentation_config['rir']
|
||||
if self.rir_config['rir_path']:
|
||||
self.rir_files = glob.glob(os.path.join(self.rir_config['rir_path'], '**/*.wav'), recursive=True)
|
||||
if "rir" in augmentation_config.keys():
|
||||
self.rir_config = augmentation_config["rir"]
|
||||
if self.rir_config["rir_path"]:
|
||||
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
|
||||
self.use_rir = True
|
||||
|
||||
print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
|
||||
|
@ -111,9 +113,15 @@ class AugmentWAV(object):
|
|||
|
||||
def additive_noise(self, noise_type, audio):
|
||||
|
||||
clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
|
||||
clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
|
||||
|
||||
noise_list = random.sample(self.noise_list[noise_type], random.randint(self.additive_noise_config[noise_type]['min_num_noises'], self.additive_noise_config[noise_type]['max_num_noises']))
|
||||
noise_list = random.sample(
|
||||
self.noise_list[noise_type],
|
||||
random.randint(
|
||||
self.additive_noise_config[noise_type]["min_num_noises"],
|
||||
self.additive_noise_config[noise_type]["max_num_noises"],
|
||||
),
|
||||
)
|
||||
|
||||
audio_len = audio.shape[0]
|
||||
noises_wav = None
|
||||
|
@ -123,7 +131,10 @@ class AugmentWAV(object):
|
|||
if noiseaudio.shape[0] < audio_len:
|
||||
continue
|
||||
|
||||
noise_snr = random.uniform(self.additive_noise_config[noise_type]['min_snr_in_db'], self.additive_noise_config[noise_type]['max_num_noises'])
|
||||
noise_snr = random.uniform(
|
||||
self.additive_noise_config[noise_type]["min_snr_in_db"],
|
||||
self.additive_noise_config[noise_type]["max_num_noises"],
|
||||
)
|
||||
noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
|
||||
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
|
||||
|
||||
|
@ -144,7 +155,7 @@ class AugmentWAV(object):
|
|||
rir_file = random.choice(self.rir_files)
|
||||
rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
|
||||
rir = rir / np.sqrt(np.sum(rir ** 2))
|
||||
return signal.convolve(audio, rir, mode=self.rir_config['conv_mode'])[:audio_len]
|
||||
return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
|
||||
|
||||
def apply_one(self, audio):
|
||||
noise_type = random.choice(self.global_noise_list)
|
||||
|
@ -153,17 +164,25 @@ class AugmentWAV(object):
|
|||
|
||||
return self.additive_noise(noise_type, audio)
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(c):
|
||||
if c.model_params['model_name'].lower() == 'lstm':
|
||||
model = LSTMSpeakerEncoder(c.model_params["input_dim"], c.model_params["proj_dim"], c.model_params["lstm_dim"], c.model_params["num_lstm_layers"])
|
||||
elif c.model_params['model_name'].lower() == 'resnet':
|
||||
if c.model_params["model_name"].lower() == "lstm":
|
||||
model = LSTMSpeakerEncoder(
|
||||
c.model_params["input_dim"],
|
||||
c.model_params["proj_dim"],
|
||||
c.model_params["lstm_dim"],
|
||||
c.model_params["num_lstm_layers"],
|
||||
)
|
||||
elif c.model_params["model_name"].lower() == "resnet":
|
||||
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
|
||||
return model
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
|
||||
checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
|
|
|
@ -6,6 +6,7 @@ from tests import get_tests_input_path
|
|||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
|
||||
file_path = get_tests_input_path()
|
||||
|
||||
|
||||
|
@ -39,6 +40,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
|
|||
assert output.shape[1] == 256
|
||||
assert len(output.shape) == 2
|
||||
|
||||
|
||||
class ResNetSpeakerEncoderTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
|
@ -65,6 +67,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
|
|||
assert output.shape[1] == 256
|
||||
assert len(output.shape) == 2
|
||||
|
||||
|
||||
class GE2ELossTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
|
@ -92,6 +95,7 @@ class GE2ELossTests(unittest.TestCase):
|
|||
output = loss.forward(dummy_input)
|
||||
assert output.item() < 0.005
|
||||
|
||||
|
||||
class AngleProtoLossTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
|
@ -121,6 +125,7 @@ class AngleProtoLossTests(unittest.TestCase):
|
|||
output = loss.forward(dummy_input)
|
||||
assert output.item() < 0.005
|
||||
|
||||
|
||||
class SoftmaxAngleProtoLossTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
|
|
|
@ -46,7 +46,7 @@ run_cli(command_train)
|
|||
shutil.rmtree(continue_path)
|
||||
|
||||
# test resnet speaker encoder
|
||||
config.model_params['model_name'] = "resnet"
|
||||
config.model_params["model_name"] = "resnet"
|
||||
config.save_json(config_path)
|
||||
|
||||
# train the model for one epoch
|
||||
|
|
Loading…
Reference in New Issue