linter and test updates for speaker_encoder, gmm_Attention

Eren Golge 2019-11-12 12:42:42 +01:00
parent 1401a0db6b
commit df1b8b3ec7
14 changed files with 171 additions and 207 deletions

View File

@@ -136,6 +136,8 @@ class GravesAttention(nn.Module):
         self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
         self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)

+    # pylint: disable=R0201
+    # pylint: disable=unused-argument
     def preprocess_inputs(self, inputs):
         return None
@@ -376,8 +378,7 @@ def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
                                           attention_location_kernel_size, windowing,
                                           norm, forward_attn, trans_agent,
                                           forward_attn_mask)
-    elif attn_type == "graves":
+    if attn_type == "graves":
         return GravesAttention(query_dim, attn_K)
-    else:
-        raise RuntimeError(
-            " [!] Given Attention Type '{attn_type}' is not exist.")
+    raise RuntimeError(
+        " [!] Given Attention Type '{attn_type}' is not exist.")

View File

@@ -27,6 +27,7 @@ class Tacotron2(nn.Module):
                  separate_stopnet=True,
                  bidirectional_decoder=False):
         super(Tacotron2, self).__init__()
+        self.postnet_output_dim = postnet_output_dim
         self.decoder_output_dim = decoder_output_dim
         self.n_frames_per_step = r
         self.bidirectional_decoder = bidirectional_decoder
@@ -50,7 +51,7 @@ class Tacotron2(nn.Module):
             location_attn, attn_K, separate_stopnet, proj_speaker_dim)
         if self.bidirectional_decoder:
             self.decoder_backward = copy.deepcopy(self.decoder)
-        self.postnet = Postnet(self.decoder_output_dim)
+        self.postnet = Postnet(self.postnet_output_dim)

     def _init_states(self):
         self.speaker_embeddings = None
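The second hunk sizes the Postnet by postnet_output_dim instead of decoder_output_dim; the two need not be equal in general, so the model now stores them separately. A toy residual illustration, not the repo's Postnet, with an assumed dimension of 80:

# Toy illustration (not the repo's Postnet): the refinement convolution must be
# built with the dimensionality of the tensor it post-processes, which this
# commit stores separately as postnet_output_dim.
import torch
import torch.nn as nn

class ToyPostnet(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        self.conv = nn.Conv1d(output_dim, output_dim, kernel_size=5, padding=2)

    def forward(self, x):            # x: (batch, output_dim, time)
        return x + self.conv(x)      # residual refinement

postnet_output_dim = 80              # e.g. number of mel bins (assumed value)
frames = torch.randn(2, postnet_output_dim, 50)
postnet = ToyPostnet(postnet_output_dim)
assert postnet(frames).shape == frames.shape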

View File

@@ -6,51 +6,41 @@ import numpy as np
 from tqdm import tqdm

 import torch
-from torch.utils.data import DataLoader
-from TTS.datasets.preprocess import get_preprocessor_by_name
-from TTS.speaker_encoder.dataset import MyDataset
 from TTS.speaker_encoder.model import SpeakerEncoder
-from TTS.speaker_encoder.visual import plot_embeddings
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import load_config

 parser = argparse.ArgumentParser(
-    description='Compute embedding vectors for each wav file in a dataset. ')
-parser.add_argument(
-    'model_path',
-    type=str,
-    help='Path to model outputs (checkpoint, tensorboard etc.).')
-parser.add_argument(
-    'config_path',
-    type=str,
-    help='Path to config file for training.',
-)
-parser.add_argument(
-    'data_path',
-    type=str,
-    help='Defines the data path. It overwrites config.json.')
-parser.add_argument(
-    'output_path',
-    type=str,
-    help='path for training outputs.')
-parser.add_argument(
-    '--use_cuda', type=bool, help='flag to set cuda.', default=False
-)
+    description="Compute embedding vectors for each wav file in a dataset. "
+)
+parser.add_argument(
+    "model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.)."
+)
+parser.add_argument(
+    "config_path", type=str, help="Path to config file for training.",
+)
+parser.add_argument(
+    "data_path", type=str, help="Defines the data path. It overwrites config.json."
+)
+parser.add_argument("output_path", type=str, help="path for training outputs.")
+parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False)
 args = parser.parse_args()

 c = load_config(args.config_path)
-ap = AudioProcessor(**c['audio'])
+ap = AudioProcessor(**c["audio"])

-wav_files = glob.glob(args.data_path + '/**/*.wav', recursive=True)
-output_files = [wav_file.replace(args.data_path, args.output_path).replace(
-    '.wav', '.npy') for wav_file in wav_files]
+wav_files = glob.glob(args.data_path + "/**/*.wav", recursive=True)
+output_files = [
+    wav_file.replace(args.data_path, args.output_path).replace(".wav", ".npy")
+    for wav_file in wav_files
+]

 for output_file in output_files:
     os.makedirs(os.path.dirname(output_file), exist_ok=True)

 model = SpeakerEncoder(**c.model)
-model.load_state_dict(torch.load(args.model_path)['model'])
+model.load_state_dict(torch.load(args.model_path)["model"])
 model.eval()
 if args.use_cuda:
     model.cuda()
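One caveat about the --use_cuda argument kept above: argparse with type=bool treats any non-empty string as True, so passing "--use_cuda False" still enables CUDA. This is standard argparse behaviour, not something introduced by the commit; a standalone demonstration:

# Demonstrates the argparse type=bool pitfall for flags like --use_cuda above:
# bool("False") is True, so any non-empty value enables the flag.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--use_cuda", type=bool, default=False)

print(p.parse_args([]).use_cuda)                       # False (default)
print(p.parse_args(["--use_cuda", "False"]).use_cuda)  # True -- non-empty string
print(p.parse_args(["--use_cuda", ""]).use_cuda)       # False -- empty string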

View File

@@ -1,23 +1,12 @@
-import os
 import numpy as np
-import collections
 import torch
 import random
 from torch.utils.data import Dataset
-from TTS.utils.text import text_to_sequence, phoneme_to_sequence, pad_with_eos_bos
-from TTS.utils.data import prepare_data, prepare_tensor, prepare_stop_target


 class MyDataset(Dataset):
-    def __init__(self,
-                 ap,
-                 meta_data,
-                 voice_len=1.6,
-                 num_speakers_in_batch=64,
-                 num_utter_per_speaker=10,
-                 skip_speakers=False,
-                 verbose=False):
+    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
+                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
         """
         Args:
             ap (TTS.utils.AudioProcessor): audio processor object.
@@ -29,6 +18,7 @@ class MyDataset(Dataset):
         self.sample_rate = ap.sample_rate
         self.voice_len = voice_len
         self.seq_len = int(voice_len * self.sample_rate)
+        self.num_speakers_in_batch = num_speakers_in_batch
         self.num_utter_per_speaker = num_utter_per_speaker
         self.skip_speakers = skip_speakers
         self.ap = ap
@@ -47,16 +37,16 @@ class MyDataset(Dataset):
     def load_data(self, idx):
         text, wav_file, speaker_name = self.items[idx]
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
-        mel = self.ap.melspectrogram(wav).astype('float32')
+        mel = self.ap.melspectrogram(wav).astype("float32")
         # sample seq_len

         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]

         sample = {
-            'mel': mel,
-            'item_idx': self.items[idx][1],
-            'speaker_name': speaker_name
+            "mel": mel,
+            "item_idx": self.items[idx][1],
+            "speaker_name": speaker_name,
         }
         return sample
@@ -64,26 +54,32 @@ class MyDataset(Dataset):
         """
         Find unique speaker ids and create a dict mapping utterances from speaker id
         """
-        speakers = list(set([item[-1] for item in self.items]))
+        speakers = list({item[-1] for item in self.items})
         self.speaker_to_utters = {}
         self.speakers = []
         for speaker in speakers:
             speaker_utters = [item[1] for item in self.items if item[2] == speaker]
             if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
-                print(f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}.")
+                print(
+                    f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}."
+                )
             else:
                 self.speakers.append(speaker)
                 self.speaker_to_utters[speaker] = speaker_utters

     def __len__(self):
-        return int(1e+10)
+        return int(1e10)

     def __sample_speaker(self):
         speaker = random.sample(self.speakers, 1)[0]
         if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
-            utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
+            utters = random.choices(
+                self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
+            )
         else:
-            utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
+            utters = random.sample(
+                self.speaker_to_utters[speaker], self.num_utter_per_speaker
+            )
         return speaker, utters

     def __sample_speaker_utterances(self, speaker):
@@ -92,7 +88,7 @@ class MyDataset(Dataset):
         """
         feats = []
         labels = []
-        for idx in range(self.num_utter_per_speaker):
+        for _ in range(self.num_utter_per_speaker):
            # TODO:dummy but works
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
@@ -104,11 +100,10 @@ class MyDataset(Dataset):
                     wav = self.load_wav(utter)
                     if wav.shape[0] - self.seq_len > 0:
                         break
-                    else:
-                        self.speaker_to_utters[speaker].remove(utter)
+                    self.speaker_to_utters[speaker].remove(utter)
                 offset = random.randint(0, wav.shape[0] - self.seq_len)
-                mel = self.ap.melspectrogram(wav[offset:offset+self.seq_len])
+                mel = self.ap.melspectrogram(wav[offset : offset + self.seq_len])
                 feats.append(torch.FloatTensor(mel))
                 labels.append(speaker)
         return feats, labels
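The sampling logic above crops a random fixed-length window of voice_len seconds from each utterance before computing the mel spectrogram. A standalone sketch of that cropping step, not part of the diff; the 16 kHz sample rate and the random array standing in for a decoded wav are assumptions:

# Minimal sketch of the fixed-length window sampling done in
# __sample_speaker_utterances above.
import random
import numpy as np

sample_rate, voice_len = 16000, 1.6
seq_len = int(voice_len * sample_rate)                      # as in MyDataset.__init__
wav = np.random.randn(3 * sample_rate).astype(np.float32)  # stand-in utterance

offset = random.randint(0, wav.shape[0] - seq_len)
segment = wav[offset : offset + seq_len]                    # what ap.melspectrogram gets
assert segment.shape == (seq_len,)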

View File

@@ -5,9 +5,8 @@ import torch.nn.functional as F

 # adapted from https://github.com/cvqluu/GE2E-Loss
 class GE2ELoss(nn.Module):
-
-    def __init__(self, init_w=10.0, init_b=-5.0, loss_method='softmax'):
-        '''
+    def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
+        """
         Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
         Accepts an input of size (N, M, D)
             where N is the number of speakers in the batch,
@@ -16,24 +15,27 @@ class GE2ELoss(nn.Module):
         Args:
             - init_w (float): defines the initial value of w in Equation (5) of [1]
             - init_b (float): definies the initial value of b in Equation (5) of [1]
-        '''
+        """
         super(GE2ELoss, self).__init__()
+        # pylint: disable=E1102
         self.w = nn.Parameter(torch.tensor(init_w))
+        # pylint: disable=E1102
         self.b = nn.Parameter(torch.tensor(init_b))
         self.loss_method = loss_method

-        assert self.loss_method in ['softmax', 'contrast']
+        assert self.loss_method in ["softmax", "contrast"]

-        if self.loss_method == 'softmax':
+        if self.loss_method == "softmax":
             self.embed_loss = self.embed_loss_softmax
-        if self.loss_method == 'contrast':
+        if self.loss_method == "contrast":
             self.embed_loss = self.embed_loss_contrast

+    # pylint: disable=R0201
     def calc_new_centroids(self, dvecs, centroids, spkr, utt):
-        '''
+        """
         Calculates the new centroids excluding the reference utterance
-        '''
-        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt+1:]))
+        """
+        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :]))
         excl = torch.mean(excl, 0)
         new_centroids = []
         for i, centroid in enumerate(centroids):
@@ -44,26 +46,36 @@ class GE2ELoss(nn.Module):
         return torch.stack(new_centroids)

     def calc_cosine_sim(self, dvecs, centroids):
-        '''
+        """
         Make the cosine similarity matrix with dims (N,M,N)
-        '''
+        """
         cos_sim_matrix = []
         for spkr_idx, speaker in enumerate(dvecs):
             cs_row = []
             for utt_idx, utterance in enumerate(speaker):
                 new_centroids = self.calc_new_centroids(
-                    dvecs, centroids, spkr_idx, utt_idx)
+                    dvecs, centroids, spkr_idx, utt_idx
+                )
                 # vector based cosine similarity for speed
-                cs_row.append(torch.clamp(torch.mm(utterance.unsqueeze(1).transpose(0, 1), new_centroids.transpose(
-                    0, 1)) / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), 1e-6))
+                cs_row.append(
+                    torch.clamp(
+                        torch.mm(
+                            utterance.unsqueeze(1).transpose(0, 1),
+                            new_centroids.transpose(0, 1),
+                        )
+                        / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)),
+                        1e-6,
+                    )
+                )
             cs_row = torch.cat(cs_row, dim=0)
             cos_sim_matrix.append(cs_row)
         return torch.stack(cos_sim_matrix)

+    # pylint: disable=R0201
     def embed_loss_softmax(self, dvecs, cos_sim_matrix):
-        '''
+        """
         Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
-        '''
+        """
         N, M, _ = dvecs.shape
         L = []
         for j in range(N):
@@ -74,10 +86,11 @@ class GE2ELoss(nn.Module):
             L.append(L_row)
         return torch.stack(L)

+    # pylint: disable=R0201
     def embed_loss_contrast(self, dvecs, cos_sim_matrix):
-        '''
+        """
         Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
-        '''
+        """
         N, M, _ = dvecs.shape
         L = []
         for j in range(N):
@@ -85,17 +98,21 @@ class GE2ELoss(nn.Module):
             for i in range(M):
                 centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
                 excl_centroids_sigmoids = torch.cat(
-                    (centroids_sigmoids[:j], centroids_sigmoids[j+1:]))
+                    (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
+                )
                 L_row.append(
-                    1. - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids))
+                    1.0
+                    - torch.sigmoid(cos_sim_matrix[j, i, j])
+                    + torch.max(excl_centroids_sigmoids)
+                )
             L_row = torch.stack(L_row)
             L.append(L_row)
         return torch.stack(L)

     def forward(self, dvecs):
-        '''
+        """
         Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
-        '''
+        """
         centroids = torch.mean(dvecs, 1)
         cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
         torch.clamp(self.w, 1e-6)
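A minimal usage sketch of the loss above, not part of the diff but consistent with the docstring and the tests later in this commit: the input is a batch of d-vectors shaped (num_speakers, num_utterances_per_speaker, dvec_dim) and the result is a scalar.

# Usage sketch for GE2ELoss; shapes follow the docstring and the tests below.
import torch
from TTS.speaker_encoder.loss import GE2ELoss

dvecs = torch.rand(4, 5, 64)            # 4 speakers x 5 utterances x 64-dim d-vectors
criterion = GE2ELoss(loss_method="softmax")
loss = criterion(dvecs)                 # scalar tensor
print(loss.item())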

View File

@@ -13,7 +13,7 @@ class LSTMWithProjection(nn.Module):
     def forward(self, x):
         self.lstm.flatten_parameters()
-        o, (h, c) = self.lstm(x)
+        o, (_, _) = self.lstm(x)
         return self.linear(o)
@@ -22,16 +22,16 @@
         super().__init__()
         layers = []
         layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
-        for _ in range(num_lstm_layers-1):
+        for _ in range(num_lstm_layers - 1):
             layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
         self.layers = nn.Sequential(*layers)
         self._init_layers()

     def _init_layers(self):
         for name, param in self.layers.named_parameters():
-            if 'bias' in name:
+            if "bias" in name:
                 nn.init.constant_(param, 0.0)
-            elif 'weight' in name:
+            elif "weight" in name:
                 nn.init.xavier_normal_(param)

     def forward(self, x):
@@ -81,7 +81,8 @@
             if embed is None:
                 embed = self.inference(frames)
             else:
-                embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
+                embed[cur_iter <= num_iters, :] += self.inference(
+                    frames[cur_iter <= num_iters, :, :]
+                )
         return embed / num_iters

View File

@@ -4,22 +4,21 @@ import torch as T

 from TTS.speaker_encoder.model import SpeakerEncoder
 from TTS.speaker_encoder.loss import GE2ELoss
-from TTS.speaker_encoder.dataset import MyDataset
-from TTS.utils.audio import AudioProcessor
-from torch.utils.data import DataLoader
-from TTS.datasets.preprocess import libri_tts
 from TTS.utils.generic_utils import load_config

 file_path = os.path.dirname(os.path.realpath(__file__)) + "/../tests/"
-c = load_config(os.path.join(file_path, 'test_config.json'))
+c = load_config(os.path.join(file_path, "test_config.json"))


 class SpeakerEncoderTests(unittest.TestCase):
+    # pylint: disable=R0201
     def test_in_out(self):
         dummy_input = T.rand(4, 20, 80)  # B x T x D
         dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
-        model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
+        model = SpeakerEncoder(
+            input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3
+        )
         # computing d vectors
         output = model.forward(dummy_input)
         assert output.shape[0] == 4
@@ -35,8 +34,10 @@ class SpeakerEncoderTests(unittest.TestCase):
         # check normalization
         output_norm = T.nn.functional.normalize(output, dim=1, p=2)
         assert_diff = (output_norm - output).sum().item()
-        assert output.type() == 'torch.FloatTensor'
-        assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
+        assert output.type() == "torch.FloatTensor"
+        assert (
+            abs(assert_diff) < 1e-4
+        ), f" [!] output_norm has wrong values - {assert_diff}"
         # compute d for a given batch
         dummy_input = T.rand(1, 240, 80)  # B x T x D
         output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5)
@@ -45,23 +46,29 @@ class SpeakerEncoderTests(unittest.TestCase):
         assert len(output.shape) == 2


 class GE2ELossTests(unittest.TestCase):
+    # pylint: disable=R0201
     def test_in_out(self):
         # check random input
         dummy_input = T.rand(4, 5, 64)  # num_speaker x num_utterance x dim
-        loss = GE2ELoss(loss_method='softmax')
+        loss = GE2ELoss(loss_method="softmax")
         output = loss.forward(dummy_input)
-        assert output.item() >= 0.
+        assert output.item() >= 0.0
         # check all zeros
         dummy_input = T.ones(4, 5, 64)  # num_speaker x num_utterance x dim
-        loss = GE2ELoss(loss_method='softmax')
+        loss = GE2ELoss(loss_method="softmax")
         output = loss.forward(dummy_input)
         # check speaker loss with orthogonal d-vectors
         dummy_input = T.empty(3, 64)
         dummy_input = T.nn.init.orthogonal(dummy_input)
-        dummy_input = T.cat([dummy_input[0].repeat(5, 1, 1).transpose(0, 1), dummy_input[1].repeat(5, 1, 1).transpose(0, 1), dummy_input[2].repeat(5, 1, 1).transpose(0, 1)])  # num_speaker x num_utterance x dim
-        loss = GE2ELoss(loss_method='softmax')
+        dummy_input = T.cat(
+            [
+                dummy_input[0].repeat(5, 1, 1).transpose(0, 1),
+                dummy_input[1].repeat(5, 1, 1).transpose(0, 1),
+                dummy_input[2].repeat(5, 1, 1).transpose(0, 1),
+            ]
+        )  # num_speaker x num_utterance x dim
+        loss = GE2ELoss(loss_method="softmax")
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
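A back-of-the-envelope check of why the last assertion can expect a loss under 0.005, not part of the diff: with three mutually orthogonal d-vectors repeated five times each, every embedding has cosine similarity about 1 to its own (exclusive) centroid and about 0 to the other two, so with the defaults init_w=10.0 and init_b=-5.0 the per-embedding softmax loss is tiny. This ignores the clamping details in the loss and is only a sanity check.

# Rough sanity check of the < 0.005 threshold above, assuming the GE2E softmax
# loss L = -log(exp(w*cos_own + b) / sum_k exp(w*cos_k + b)) with w=10.0, b=-5.0,
# cos_own ~= 1 and cos_other ~= 0 for the two other speakers.
import math

w, b = 10.0, -5.0
own, other = w * 1.0 + b, w * 0.0 + b
loss = -math.log(math.exp(own) / (math.exp(own) + 2 * math.exp(other)))
print(loss)   # ~9e-5, well under 0.005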

View File

@@ -5,24 +5,21 @@ import time
 import traceback

 import torch
-from torch import optim
 from torch.utils.data import DataLoader

 from TTS.datasets.preprocess import load_meta_data
 from TTS.speaker_encoder.dataset import MyDataset
+from TTS.speaker_encoder.generic_utils import save_best_model, save_checkpoint
 from TTS.speaker_encoder.loss import GE2ELoss
 from TTS.speaker_encoder.model import SpeakerEncoder
 from TTS.speaker_encoder.visual import plot_embeddings
-from TTS.speaker_encoder.generic_utils import save_best_model, save_checkpoint
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (NoamLR, check_update, copy_config_file,
                                      count_parameters,
                                      create_experiment_folder, get_git_branch,
-                                     gradual_training_scheduler, load_config,
-                                     remove_experiment_folder, set_init_dict,
-                                     setup_model, split_dataset)
+                                     load_config,
+                                     remove_experiment_folder, set_init_dict)
 from TTS.utils.logger import Logger
 from TTS.utils.radam import RAdam
-from TTS.utils.visual import plot_alignment, plot_spectrogram

 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
@@ -34,10 +31,6 @@ print(" > Number of GPUs: ", num_gpus)

 def setup_loader(ap, is_val=False, verbose=False):
-    global meta_data_train
-    global meta_data_eval
-    if "meta_data_train" not in globals():
-        meta_data_train, meta_data_eval = load_meta_data(c.datasets)
     if is_val:
         loader = None
     else:
@@ -63,12 +56,11 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
     best_loss = float('inf')
     avg_loss = 0
     end_time = time.time()
-    for num_iter, data in enumerate(data_loader):
+    for _, data in enumerate(data_loader):
         start_time = time.time()

         # setup input data
         inputs = data[0]
-        labels = data[1]
         loader_time = time.time() - end_time
         global_step += 1
@@ -132,68 +124,11 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
     return avg_loss, global_step


-# def evaluate(model, criterion, ap, global_step, epoch):
-#     data_loader = setup_loader(ap, is_val=True)
-#     model.eval()
-#     epoch_time = 0
-#     avg_loss = 0
-#     print("\n > Validation")
-#     with torch.no_grad():
-#         if data_loader is not None:
-#             for num_iter, data in enumerate(data_loader):
-#                 start_time = time.time()
-#                 # setup input data
-#                 inputs = data[0]
-#                 labels = data[1]
-#                 # dispatch data to GPU
-#                 if use_cuda:
-#                     inputs = inputs.cuda()
-#                     # labels = labels.cuda()
-#                 # forward pass
-#                 outputs = model.forward(inputs)
-#                 # loss computation
-#                 loss = criterion(outputs.reshape(
-#                     c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
-#                 step_time = time.time() - start_time
-#                 epoch_time += step_time
-#                 if num_iter % c.print_step == 0:
-#                     print(
-#                         " | > Loss: {:.5f} ".format(loss.item()),
-#                         flush=True)
-#                 avg_loss += float(loss.item())
-#         eval_figures = {
-#             "prediction": plot_spectrogram(const_spec, ap),
-#             "ground_truth": plot_spectrogram(gt_spec, ap),
-#             "alignment": plot_alignment(align_img)
-#         }
-#         tb_logger.tb_eval_figures(global_step, eval_figures)
-#         # Sample audio
-#         if c.model in ["Tacotron", "TacotronGST"]:
-#             eval_audio = ap.inv_spectrogram(const_spec.T)
-#         else:
-#             eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-#         tb_logger.tb_eval_audios(
-#             global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
-#         # compute average losses
-#         avg_loss /= (num_iter + 1)
-#         # Plot Validation Stats
-#         epoch_stats = {"GE2Eloss": avg_loss}
-#         tb_logger.tb_eval_stats(global_step, epoch_stats)
-#     return avg_loss


+# FIXME: move args definition/parsing inside of main?
 def main(args):  # pylint: disable=redefined-outer-name
+    # pylint: disable=global-variable-undefined
+    global meta_data_train
+    global meta_data_eval
     ap = AudioProcessor(**c.audio)
     model = SpeakerEncoder(input_dim=40,
                            proj_dim=128,
@@ -211,7 +146,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             if c.reinit_layers:
                 raise RuntimeError
             model.load_state_dict(checkpoint['model'])
-        except:
+        except KeyError:
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint, c)
@@ -239,6 +174,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     num_params = count_parameters(model)
     print("\n > Model has {} parameters".format(num_params), flush=True)

+    # pylint: disable=redefined-outer-name
+    meta_data_train, meta_data_eval = load_meta_data(c.datasets)
+
     global_step = args.restore_step
     train_loss, global_step = train(model, criterion, optimizer, scheduler, ap,
                                     global_step)
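The commented-out evaluate() removed above shows how a flat batch of encoder outputs is reshaped into the (num_speakers, num_utters, dims) layout that GE2ELoss expects; presumably the training loop uses the same layout, though that call is not part of the shown hunks. A standalone sketch of the reshape, with example sizes taken from the dataset and model defaults in this commit:

# Sketch of the d-vector reshape used in the commented-out evaluate() above:
# a flat (num_speakers * num_utters, proj_dim) batch becomes (N, M, D) for GE2ELoss.
import torch

num_speakers_in_batch, num_utter_per_speaker, proj_dim = 64, 10, 128
outputs = torch.rand(num_speakers_in_batch * num_utter_per_speaker, proj_dim)

dvecs = outputs.reshape(
    num_speakers_in_batch, outputs.shape[0] // num_speakers_in_batch, -1
)
assert dvecs.shape == (num_speakers_in_batch, num_utter_per_speaker, proj_dim)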

View File

@@ -3,10 +3,12 @@ import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt

-matplotlib.use('Agg')
+matplotlib.use("Agg")

-colormap = np.array([
+colormap = (
+    np.array(
+        [
             [76, 255, 0],
             [0, 127, 70],
             [255, 0, 0],
@@ -20,11 +22,15 @@ colormap = np.array([
             [33, 0, 127],
             [0, 0, 0],
             [183, 183, 183],
-], dtype=np.float) / 255
+        ],
+        dtype=np.float,
+    )
+    / 255
+)


 def plot_embeddings(embeddings, num_utter_per_speaker):
-    embeddings = embeddings[:10*num_utter_per_speaker]
+    embeddings = embeddings[: 10 * num_utter_per_speaker]
     model = umap.UMAP()
     projection = model.fit_transform(embeddings)
     num_speakers = embeddings.shape[0] // num_utter_per_speaker
@@ -32,7 +38,7 @@ def plot_embeddings(embeddings, num_utter_per_speaker):
     colors = [colormap[i] for i in ground_truth]

     fig, ax = plt.subplots(figsize=(16, 10))
-    im = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
+    _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
     plt.gca().set_aspect("equal", "datalim")
     plt.title("UMAP projection")
     plt.tight_layout()

View File

@@ -44,6 +44,8 @@
     "prenet_dropout": true,       // ONLY TACOTRON2 - enable/disable dropout at prenet.
     "use_forward_attn": true,     // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "forward_attn_mask": false,
+    "attention_type": "original",
+    "attention_heads": 5,
     "bidirectional_decoder": false,
     "transition_agent": false,    // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
     "location_attn": false,       // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.

View File

@@ -5,7 +5,8 @@ from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
 from TTS.layers.losses import L1LossMasked
 from TTS.utils.generic_utils import sequence_mask

-#pylint: disable=unused-variable
+# pylint: disable=unused-variable


 class PrenetTests(unittest.TestCase):
     def test_in_out(self):
@@ -49,6 +50,8 @@ class DecoderTests(unittest.TestCase):
             memory_size=4,
             attn_windowing=False,
             attn_norm="sigmoid",
+            attn_K=5,
+            attn_type="original",
             prenet_type='original',
             prenet_dropout=True,
             forward_attn=True,
@@ -77,6 +80,8 @@ class DecoderTests(unittest.TestCase):
             memory_size=4,
             attn_windowing=False,
             attn_norm="sigmoid",
+            attn_K=5,
+            attn_type="graves",
             prenet_type='original',
             prenet_dropout=True,
             forward_attn=True,

View File

@@ -117,7 +117,8 @@ def format_data(data):

 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, global_step, epoch):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=(epoch == 0))
+    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
+                               verbose=(epoch == 0))
     model.train()
     epoch_time = 0
     train_values = {

View File

@@ -15,7 +15,7 @@ class AudioProcessor(object):
                  ref_level_db=None,
                  num_freq=None,
                  power=None,
-                 preemphasis=None,
+                 preemphasis=0.0,
                  signal_norm=None,
                  symmetric_norm=None,
                  max_norm=None,
@@ -48,7 +48,7 @@ class AudioProcessor(object):
         self.do_trim_silence = do_trim_silence
         self.sound_norm = sound_norm
         self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
-        assert min_level_db ~= 0.0, " [!] min_level_db is 0"
+        assert min_level_db != 0.0, " [!] min_level_db is 0"
         members = vars(self)
         for key, value in members.items():
             print(" | > {}:{}".format(key, value))
@@ -132,12 +132,12 @@ class AudioProcessor(object):
     def apply_preemphasis(self, x):
         if self.preemphasis == 0:
-            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
+            raise RuntimeError(" [!] Preemphasis is set 0.0.")
         return scipy.signal.lfilter([1, -self.preemphasis], [1], x)

     def apply_inv_preemphasis(self, x):
         if self.preemphasis == 0:
-            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
+            raise RuntimeError(" [!] Preemphasis is set 0.0.")
         return scipy.signal.lfilter([1], [1, -self.preemphasis], x)

     def spectrogram(self, y):
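The pre-emphasis pair above is a first-order FIR filter and its exact IIR inverse; the new default of 0.0 simply disables it, and both methods raise if called in that case. A small round-trip check, not part of the diff; the 0.97 coefficient and the random signal are assumptions for illustration:

# Round-trip sketch of apply_preemphasis / apply_inv_preemphasis above.
import numpy as np
import scipy.signal

preemphasis = 0.97
x = np.random.randn(22050).astype(np.float32)

y = scipy.signal.lfilter([1, -preemphasis], [1], x)       # pre-emphasis (FIR)
x_rec = scipy.signal.lfilter([1], [1, -preemphasis], y)   # inverse (IIR)

assert np.allclose(x, x_rec, atol=1e-4)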