mirror of https://github.com/coqui-ai/TTS.git
add batched speaker encoder inference
This commit is contained in:
parent
825734a3a9
commit
208bb0f0ee
|
@ -2,7 +2,6 @@ import argparse
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
|
@ -174,15 +174,17 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
|
|
||||||
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
|
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
|
||||||
|
|
||||||
embeddings = []
|
frames_batch = []
|
||||||
for offset in offsets:
|
for offset in offsets:
|
||||||
offset = int(offset)
|
offset = int(offset)
|
||||||
end_offset = int(offset+num_frames)
|
end_offset = int(offset+num_frames)
|
||||||
frames = x[:, offset:end_offset]
|
frames = x[:, offset:end_offset]
|
||||||
embed = self.forward(frames, l2_norm=True)
|
frames_batch.append(frames)
|
||||||
embeddings.append(embed)
|
|
||||||
|
frames_batch = torch.cat(frames_batch, dim=0)
|
||||||
|
embeddings = self.forward(frames_batch, l2_norm=True)
|
||||||
|
|
||||||
embeddings = torch.stack(embeddings)
|
|
||||||
if return_mean:
|
if return_mean:
|
||||||
embeddings = torch.mean(embeddings, dim=0)
|
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
Loading…
Reference in New Issue