add compute embedding for the new speaker encoder

Edresson 2021-05-12 03:06:46 -03:00
parent 3fcc748b2e
commit 3433c2f348
2 changed files with 33 additions and 4 deletions

@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from TTS.speaker_encoder.model import SpeakerEncoder
+from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.utils.speakers import save_speaker_mapping
 from TTS.utils.audio import AudioProcessor
@@ -77,7 +77,7 @@ for output_file in output_files:
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
 
 # define Encoder model
-model = SpeakerEncoder(**c.model)
+model = setup_model(c)
 model.load_state_dict(torch.load(args.model_path)["model"])
 model.eval()
 if args.use_cuda:
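The first file is the embedding-computation script: instead of hard-coding the LSTM SpeakerEncoder class, it now calls setup_model(c), which builds whichever encoder the config describes, so the same script also serves the new ResNet encoder. A minimal sketch of the resulting loading flow, assuming a load_config helper in TTS.utils.io (its usual location in this version of the codebase) and hypothetical config/checkpoint paths:

    import torch

    from TTS.speaker_encoder.utils.generic_utils import setup_model
    from TTS.utils.io import load_config  # assumed location of the config loader

    # hypothetical paths, for illustration only
    c = load_config("config.json")

    # setup_model() reads the model name from the config and builds the
    # matching encoder class (e.g. the LSTM or the ResNet speaker encoder)
    model = setup_model(c)
    model.load_state_dict(torch.load("best_model.pth.tar")["model"])
    model.eval()

The second file adds the inference-time pieces to the ResNet encoder itself: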

@@ -124,7 +124,7 @@ class ResNetSpeakerEncoder(nn.Module):
         nn.init.xavier_normal_(out)
         return out
 
-    def forward(self, x):
+    def forward(self, x, training=True):
         x = x.transpose(1, 2)
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
@@ -140,7 +140,7 @@ class ResNetSpeakerEncoder(nn.Module):
         x = self.layer3(x)
         x = self.layer4(x)
 
-        x = x.reshape(x.size()[0],-1,x.size()[-1])
+        x = x.reshape(x.size()[0], -1, x.size()[-1])
 
         w = self.attention(x)
@@ -154,4 +154,33 @@ class ResNetSpeakerEncoder(nn.Module):
         x = x.view(x.size()[0], -1)
 
         x = self.fc(x)
+
+        if not training:
+            x = torch.nn.functional.normalize(x, p=2, dim=1)
         return x
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, overlap=0.5):
+        """
+        Generate embeddings for a batch of utterances
+        x: 1xTxD
+        """
+        num_overlap = int(num_frames * overlap)
+        max_len = x.shape[1]
+        embed = None
+        cur_iter = 0
+        for offset in range(0, max_len, num_frames - num_overlap):
+            cur_iter += 1
+            end_offset = min(x.shape[1], offset + num_frames)
+            # ignore slices with two or fewer frames, because they can break instance normalization
+            if end_offset - offset <= 1:
+                continue
+            frames = x[:, offset:end_offset]
+            if embed is None:
+                embed = self.forward(frames, training=False)
+            else:
+                embed += self.forward(frames, training=False)
+        return embed / cur_iter
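
With the defaults, compute_embedding slides 250-frame windows over the utterance with a hop of 250 - int(250 * 0.5) = 125 frames, so consecutive windows overlap by half; each window goes through forward(..., training=False), which L2-normalizes the window embedding, and the results are averaged. Note that cur_iter is incremented before the short-slice check, so a skipped tail slice still counts toward the divisor. A usage sketch, assuming the model loaded above plus an AudioProcessor ap built from c.audio; the wav path and feature layout are illustrative:

    import torch

    # hypothetical input file; ap = AudioProcessor(**c.audio)
    wav = ap.load_wav("speaker_0001.wav")
    mel = ap.melspectrogram(wav)              # shape (num_mels, T)

    # compute_embedding expects a 1xTxD batch: frames on dim 1, add a batch axis
    x = torch.from_numpy(mel.T.copy()).unsqueeze(0)  # (1, T, num_mels)

    embed = model.compute_embedding(x, num_frames=250, overlap=0.5)
    embed = embed.squeeze(0).cpu().numpy()    # utterance-level speaker embedding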