Create a batch for more fast inference on LSTM Speaker Encoder

This commit is contained in:
Edresson 2021-06-05 03:12:17 -03:00
parent cc192b6843
commit 14b209c7e9
1 changed files with 21 additions and 12 deletions

View File

@ -1,4 +1,5 @@
import torch
import numpy as np
from torch import nn
@ -70,24 +71,32 @@ class LSTMSpeakerEncoder(nn.Module):
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
def compute_embedding(self, x, num_frames=160, overlap=0.5):
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
num_overlap = int(num_frames * overlap)
max_len = x.shape[1]
embed = None
cur_iter = 0
for offset in range(0, max_len, num_frames - num_overlap):
cur_iter += 1
end_offset = min(x.shape[1], offset + num_frames)
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset+num_frames)
frames = x[:, offset:end_offset]
if embed is None:
embed = self.inference(frames)
else:
embed += self.inference(frames)
return embed / cur_iter
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
"""