add batched speaker encoder inference

Edresson 2021-05-27 20:01:00 -03:00
parent 825734a3a9
commit 208bb0f0ee
2 changed files with 7 additions and 6 deletions


@@ -2,7 +2,6 @@ import argparse
 import glob
 import os
 import numpy as np
-import torch
 from tqdm import tqdm


@@ -174,15 +174,17 @@ class ResNetSpeakerEncoder(nn.Module):
         offsets = np.linspace(0, max_len-num_frames, num=num_eval)

-        embeddings = []
+        frames_batch = []
         for offset in offsets:
             offset = int(offset)
             end_offset = int(offset+num_frames)
             frames = x[:, offset:end_offset]
-            embed = self.forward(frames, l2_norm=True)
-            embeddings.append(embed)
+            frames_batch.append(frames)
+
+        frames_batch = torch.cat(frames_batch, dim=0)
+        embeddings = self.forward(frames_batch, l2_norm=True)

-        embeddings = torch.stack(embeddings)
         if return_mean:
-            embeddings = torch.mean(embeddings, dim=0)
+            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
+
         return embeddings
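
The change replaces num_eval separate forward passes (one per sliding window) with a single pass over the windows concatenated along the batch dimension; keepdim=True keeps the averaged embedding 2-D (1 x embedding_dim), matching the shape the old torch.stack/torch.mean path produced. Below is a minimal, runnable sketch of the same pattern. ToyEncoder, the feature size 80, and the embedding size 256 are stand-ins invented for illustration; the only interface taken from the diff is forward(frames, l2_norm=True).

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    """Stand-in for ResNetSpeakerEncoder: (B, num_frames, feat_dim) -> (B, emb_dim)."""

    def __init__(self, feat_dim=80, emb_dim=256):
        super().__init__()
        self.proj = nn.Linear(feat_dim, emb_dim)

    def forward(self, frames, l2_norm=True):
        # Mean-pool over the time axis, then project; the real encoder is a ResNet.
        emb = self.proj(frames.mean(dim=1))
        return F.normalize(emb, p=2, dim=1) if l2_norm else emb


def compute_embedding(model, x, num_frames=250, num_eval=10, return_mean=True):
    # x: (1, T, feat_dim). Slice num_eval evenly spaced windows, concatenate
    # them along the batch axis, and encode them all in one forward pass.
    max_len = x.shape[1]
    num_frames = min(num_frames, max_len)
    offsets = np.linspace(0, max_len - num_frames, num=num_eval)

    frames_batch = [x[:, int(o):int(o) + num_frames] for o in offsets]
    frames_batch = torch.cat(frames_batch, dim=0)      # (num_eval, num_frames, feat_dim)
    embeddings = model.forward(frames_batch, l2_norm=True)

    if return_mean:
        # keepdim=True preserves the (1, emb_dim) shape a single-window call returns.
        embeddings = torch.mean(embeddings, dim=0, keepdim=True)
    return embeddings


x = torch.randn(1, 1000, 80)                      # one utterance, T=1000 frames
print(compute_embedding(ToyEncoder(), x).shape)   # torch.Size([1, 256])

Batching the windows trades a little peak memory (all num_eval windows are held at once) for a single kernel launch sequence instead of num_eval of them, which is the usual win on GPU.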