mirror of https://github.com/coqui-ai/TTS.git
Create a batch for faster inference on the LSTM Speaker Encoder
This commit is contained in:
parent cc192b6843
commit 14b209c7e9
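The previous compute_embedding slid a 160-frame window with 50% overlap across the utterance and called self.inference once per window, accumulating a running average. The new version instead takes num_eval evenly spaced windows of num_frames frames, concatenates them along the batch dimension, and embeds them in a single batched self.inference call, optionally averaging the resulting embeddings when return_mean is set.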
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 
@@ -70,24 +71,32 @@ class LSTMSpeakerEncoder(nn.Module):
         d = torch.nn.functional.normalize(d, p=2, dim=1)
         return d
 
-    def compute_embedding(self, x, num_frames=160, overlap=0.5):
+    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
         """
         Generate embeddings for a batch of utterances
         x: 1xTxD
         """
-        num_overlap = int(num_frames * overlap)
         max_len = x.shape[1]
-        embed = None
-        cur_iter = 0
-        for offset in range(0, max_len, num_frames - num_overlap):
-            cur_iter += 1
-            end_offset = min(x.shape[1], offset + num_frames)
+
+        if max_len < num_frames:
+            num_frames = max_len
+
+        offsets = np.linspace(0, max_len-num_frames, num=num_eval)
+
+        frames_batch = []
+        for offset in offsets:
+            offset = int(offset)
+            end_offset = int(offset+num_frames)
             frames = x[:, offset:end_offset]
-            if embed is None:
-                embed = self.inference(frames)
-            else:
-                embed += self.inference(frames)
-        return embed / cur_iter
+            frames_batch.append(frames)
+
+        frames_batch = torch.cat(frames_batch, dim=0)
+        embeddings = self.inference(frames_batch)
+
+        if return_mean:
+            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
+
+        return embeddings
 
     def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
         """
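Below is a minimal, self-contained sketch of the new windowing logic, useful for seeing why one batched forward pass replaces num_eval sequential ones. DummyEncoder, its feat_dim/embed_dim sizes, and the standalone compute_embedding wrapper are hypothetical stand-ins, not part of the repository; only the slicing and batching mirror the diff above.

import numpy as np
import torch
from torch import nn

# DummyEncoder is a hypothetical stand-in for LSTMSpeakerEncoder: any module
# whose inference() maps (batch, num_frames, feat_dim) features to
# (batch, embed_dim) embeddings fits this sketch.
class DummyEncoder(nn.Module):
    def __init__(self, feat_dim=40, embed_dim=256):
        super().__init__()
        self.proj = nn.Linear(feat_dim, embed_dim)

    def inference(self, frames):
        # Mean-pool over time, then project: (B, T, D) -> (B, embed_dim).
        return self.proj(frames.mean(dim=1))

def compute_embedding(model, x, num_frames=250, num_eval=10, return_mean=True):
    # x: 1xTxD utterance features, matching the docstring in the diff.
    max_len = x.shape[1]
    if max_len < num_frames:
        num_frames = max_len

    # num_eval evenly spaced window start positions over the utterance.
    offsets = np.linspace(0, max_len - num_frames, num=num_eval)

    # Slice every window and stack the slices along the batch dimension,
    # so the encoder runs once instead of num_eval times.
    frames_batch = torch.cat(
        [x[:, int(offset):int(offset) + num_frames] for offset in offsets],
        dim=0,
    )
    embeddings = model.inference(frames_batch)

    if return_mean:
        embeddings = torch.mean(embeddings, dim=0, keepdim=True)
    return embeddings

if __name__ == "__main__":
    model = DummyEncoder()
    x = torch.randn(1, 1000, 40)  # one utterance: 1000 frames of 40-dim features
    print(compute_embedding(model, x).shape)  # torch.Size([1, 256])

The trade-off is straightforward: concatenating the windows multiplies activation memory by num_eval, but a single batched call is typically much faster than num_eval sequential calls, which is the point of the commit.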