mirror of https://github.com/coqui-ai/TTS.git
add batched speaker encoder inference
This commit is contained in:
parent
825734a3a9
commit
208bb0f0ee
|
@ -2,7 +2,6 @@ import argparse
|
|||
import glob
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
|
@ -174,15 +174,17 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
|
||||
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
|
||||
|
||||
embeddings = []
|
||||
frames_batch = []
|
||||
for offset in offsets:
|
||||
offset = int(offset)
|
||||
end_offset = int(offset+num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
embed = self.forward(frames, l2_norm=True)
|
||||
embeddings.append(embed)
|
||||
frames_batch.append(frames)
|
||||
|
||||
frames_batch = torch.cat(frames_batch, dim=0)
|
||||
embeddings = self.forward(frames_batch, l2_norm=True)
|
||||
|
||||
embeddings = torch.stack(embeddings)
|
||||
if return_mean:
|
||||
embeddings = torch.mean(embeddings, dim=0)
|
||||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||
|
||||
return embeddings
|
||||
|
|
Loading…
Reference in New Issue