mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #581 from Edresson/dev
Compute speaker embeddings in batch for the LSTM Speaker Encoder and Compute embeddings/ finding chars using config file.
This commit is contained in:
commit
30eed347b6
|
@ -1,80 +1,46 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from TTS.config import BaseDatasetConfig, load_config
|
from argparse import RawTextHelpFormatter
|
||||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
from TTS.config import load_config
|
||||||
from TTS.tts.datasets import load_meta_data
|
from TTS.tts.datasets import load_meta_data
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
|
description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
|
||||||
|
"""
|
||||||
|
Example runs:
|
||||||
|
python TTS/bin/compute_embeddings.py speaker_encoder_model.pth.tar speaker_encoder_config.json dataset_config.json embeddings_output_path/
|
||||||
|
""",
|
||||||
|
formatter_class=RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
parser.add_argument("model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).")
|
parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"config_path",
|
"config_path",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to config file for training.",
|
help="Path to model config file.",
|
||||||
)
|
)
|
||||||
parser.add_argument("data_path", type=str, help="Data path for wav files - directory or CSV file")
|
|
||||||
parser.add_argument("output_path", type=str, help="path for output speakers.json.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--target_dataset",
|
"config_dataset_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
help="Path to dataset config file.",
|
||||||
help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.",
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
|
||||||
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
|
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
|
||||||
parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|")
|
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
c_dataset = load_config(args.config_dataset_path)
|
||||||
|
|
||||||
c = load_config(args.config_path)
|
meta_data_train, meta_data_eval = load_meta_data(c_dataset.datasets, eval_split=args.eval)
|
||||||
ap = AudioProcessor(**c["audio"])
|
wav_files = meta_data_train + meta_data_eval
|
||||||
|
|
||||||
data_path = args.data_path
|
speaker_manager = SpeakerManager(encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda)
|
||||||
split_ext = os.path.splitext(data_path)
|
|
||||||
sep = args.separator
|
|
||||||
|
|
||||||
if args.target_dataset != "":
|
|
||||||
# if target dataset is defined
|
|
||||||
dataset_config = [
|
|
||||||
BaseDatasetConfig(name=args.target_dataset, path=args.data_path, meta_file_train=None, meta_file_val=None),
|
|
||||||
]
|
|
||||||
wav_files, _ = load_meta_data(dataset_config, eval_split=False)
|
|
||||||
else:
|
|
||||||
# if target dataset is not defined
|
|
||||||
if len(split_ext) > 0 and split_ext[1].lower() == ".csv":
|
|
||||||
# Parse CSV
|
|
||||||
print(f"CSV file: {data_path}")
|
|
||||||
with open(data_path) as f:
|
|
||||||
wav_path = os.path.join(os.path.dirname(data_path), "wavs")
|
|
||||||
wav_files = []
|
|
||||||
print(f"Separator is: {sep}")
|
|
||||||
for line in f:
|
|
||||||
components = line.split(sep)
|
|
||||||
if len(components) != 2:
|
|
||||||
print("Invalid line")
|
|
||||||
continue
|
|
||||||
wav_file = os.path.join(wav_path, components[0] + ".wav")
|
|
||||||
# print(f'wav_file: {wav_file}')
|
|
||||||
if os.path.exists(wav_file):
|
|
||||||
wav_files.append(wav_file)
|
|
||||||
print(f"Count of wavs imported: {len(wav_files)}")
|
|
||||||
else:
|
|
||||||
# Parse all wav files in data_path
|
|
||||||
wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)
|
|
||||||
|
|
||||||
# define Encoder model
|
|
||||||
model = setup_model(c)
|
|
||||||
model.load_state_dict(torch.load(args.model_path)["model"])
|
|
||||||
model.eval()
|
|
||||||
if args.use_cuda:
|
|
||||||
model.cuda()
|
|
||||||
|
|
||||||
# compute speaker embeddings
|
# compute speaker embeddings
|
||||||
speaker_mapping = {}
|
speaker_mapping = {}
|
||||||
|
@ -85,27 +51,24 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
|
||||||
else:
|
else:
|
||||||
speaker_name = None
|
speaker_name = None
|
||||||
|
|
||||||
mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
|
# extract the embedding
|
||||||
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
|
embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
|
||||||
if args.use_cuda:
|
|
||||||
mel_spec = mel_spec.cuda()
|
|
||||||
embedd = model.compute_embedding(mel_spec)
|
|
||||||
embedd = embedd.detach().cpu().numpy()
|
|
||||||
|
|
||||||
# create speaker_mapping if target dataset is defined
|
# create speaker_mapping if target dataset is defined
|
||||||
wav_file_name = os.path.basename(wav_file)
|
wav_file_name = os.path.basename(wav_file)
|
||||||
speaker_mapping[wav_file_name] = {}
|
speaker_mapping[wav_file_name] = {}
|
||||||
speaker_mapping[wav_file_name]["name"] = speaker_name
|
speaker_mapping[wav_file_name]["name"] = speaker_name
|
||||||
speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
|
speaker_mapping[wav_file_name]["embedding"] = embedd
|
||||||
|
|
||||||
if speaker_mapping:
|
if speaker_mapping:
|
||||||
# save speaker_mapping if target dataset is defined
|
# save speaker_mapping if target dataset is defined
|
||||||
if ".json" not in args.output_path:
|
if '.json' not in args.output_path:
|
||||||
mapping_file_path = os.path.join(args.output_path, "speakers.json")
|
mapping_file_path = os.path.join(args.output_path, "speakers.json")
|
||||||
else:
|
else:
|
||||||
mapping_file_path = args.output_path
|
mapping_file_path = args.output_path
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
|
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
|
||||||
speaker_manager = SpeakerManager()
|
|
||||||
# pylint: disable=W0212
|
# pylint: disable=W0212
|
||||||
speaker_manager._save_json(mapping_file_path, speaker_mapping)
|
speaker_manager._save_json(mapping_file_path, speaker_mapping)
|
||||||
print("Speaker embeddings saved at:", mapping_file_path)
|
print("Speaker embeddings saved at:", mapping_file_path)
|
||||||
|
|
|
@ -227,7 +227,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
ap = AudioProcessor(**c.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
# load data instances
|
# load data instances
|
||||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=args.eval)
|
||||||
|
|
||||||
# use eval and training partitions
|
# use eval and training partitions
|
||||||
meta_data = meta_data_train + meta_data_eval
|
meta_data = meta_data_train + meta_data_eval
|
||||||
|
@ -271,6 +271,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
|
parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
|
||||||
parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
|
parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
|
||||||
parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
|
parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
|
||||||
|
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
|
|
|
@ -1,41 +1,42 @@
|
||||||
"""Find all the unique characters in a dataset"""
|
"""Find all the unique characters in a dataset"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
from argparse import RawTextHelpFormatter
|
from argparse import RawTextHelpFormatter
|
||||||
|
from TTS.tts.datasets import load_meta_data
|
||||||
from TTS.tts.datasets import _get_preprocessor_by_name
|
from TTS.config import load_config
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# pylint: disable=bad-option-value
|
# pylint: disable=bad-option-value
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
|
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
|
||||||
"""Target dataset must be defined in TTS.tts.datasets.formatters\n\n"""
|
|
||||||
"""
|
"""
|
||||||
Example runs:
|
Example runs:
|
||||||
|
|
||||||
python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
|
python TTS/bin/find_unique_chars.py --config_path config.json
|
||||||
""",
|
""",
|
||||||
formatter_class=RawTextHelpFormatter,
|
formatter_class=RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset", type=str, default="", help="One of the target dataset names in TTS.tts.datasets.formatters."
|
"--config_path", type=str, help="Path to dataset config file.", required=True
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--meta_file", type=str, default=None, help="Path to the transcriptions file of the dataset.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
preprocessor = _get_preprocessor_by_name(args.dataset)
|
c = load_config(args.config_path)
|
||||||
items = preprocessor(os.path.dirname(args.meta_file), os.path.basename(args.meta_file))
|
|
||||||
|
# load all datasets
|
||||||
|
train_items, eval_items = load_meta_data(c.datasets, eval_split=True)
|
||||||
|
items = train_items + eval_items
|
||||||
|
|
||||||
texts = "".join(item[0] for item in items)
|
texts = "".join(item[0] for item in items)
|
||||||
chars = set(texts)
|
chars = set(texts)
|
||||||
lower_chars = filter(lambda c: c.islower(), chars)
|
lower_chars = filter(lambda c: c.islower(), chars)
|
||||||
|
chars_force_lower = [c.lower() for c in chars]
|
||||||
|
chars_force_lower = set(chars_force_lower)
|
||||||
|
|
||||||
print(f" > Number of unique characters: {len(chars)}")
|
print(f" > Number of unique characters: {len(chars)}")
|
||||||
print(f" > Unique characters: {''.join(sorted(chars))}")
|
print(f" > Unique characters: {''.join(sorted(chars))}")
|
||||||
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
|
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
|
||||||
|
print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,24 +71,32 @@ class LSTMSpeakerEncoder(nn.Module):
|
||||||
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def compute_embedding(self, x, num_frames=160, overlap=0.5):
|
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
||||||
"""
|
"""
|
||||||
Generate embeddings for a batch of utterances
|
Generate embeddings for a batch of utterances
|
||||||
x: 1xTxD
|
x: 1xTxD
|
||||||
"""
|
"""
|
||||||
num_overlap = int(num_frames * overlap)
|
|
||||||
max_len = x.shape[1]
|
max_len = x.shape[1]
|
||||||
embed = None
|
|
||||||
cur_iter = 0
|
if max_len < num_frames:
|
||||||
for offset in range(0, max_len, num_frames - num_overlap):
|
num_frames = max_len
|
||||||
cur_iter += 1
|
|
||||||
end_offset = min(x.shape[1], offset + num_frames)
|
offsets = np.linspace(0, max_len-num_frames, num=num_eval)
|
||||||
|
|
||||||
|
frames_batch = []
|
||||||
|
for offset in offsets:
|
||||||
|
offset = int(offset)
|
||||||
|
end_offset = int(offset+num_frames)
|
||||||
frames = x[:, offset:end_offset]
|
frames = x[:, offset:end_offset]
|
||||||
if embed is None:
|
frames_batch.append(frames)
|
||||||
embed = self.inference(frames)
|
|
||||||
else:
|
frames_batch = torch.cat(frames_batch, dim=0)
|
||||||
embed += self.inference(frames)
|
embeddings = self.inference(frames_batch)
|
||||||
return embed / cur_iter
|
|
||||||
|
if return_mean:
|
||||||
|
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
|
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
|
||||||
"""
|
"""
|
||||||
|
@ -110,9 +119,11 @@ class LSTMSpeakerEncoder(nn.Module):
|
||||||
return embed / num_iters
|
return embed / num_iters
|
||||||
|
|
||||||
# pylint: disable=unused-argument, redefined-builtin
|
# pylint: disable=unused-argument, redefined-builtin
|
||||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False):
|
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
self.load_state_dict(state["model"])
|
self.load_state_dict(state["model"])
|
||||||
|
if use_cuda:
|
||||||
|
self.cuda()
|
||||||
if eval:
|
if eval:
|
||||||
self.eval()
|
self.eval()
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
|
|
@ -199,3 +199,12 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
|
self.load_state_dict(state["model"])
|
||||||
|
if use_cuda:
|
||||||
|
self.cuda()
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
|
@ -202,16 +202,20 @@ def libri_tts(root_path, meta_files=None):
|
||||||
items = []
|
items = []
|
||||||
if meta_files is None:
|
if meta_files is None:
|
||||||
meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
|
meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
|
||||||
|
else:
|
||||||
|
if isinstance(meta_files, str):
|
||||||
|
meta_files = [os.path.join(root_path, meta_files)]
|
||||||
|
|
||||||
for meta_file in meta_files:
|
for meta_file in meta_files:
|
||||||
_meta_file = os.path.basename(meta_file).split(".")[0]
|
_meta_file = os.path.basename(meta_file).split(".")[0]
|
||||||
speaker_name = _meta_file.split("_")[0]
|
|
||||||
chapter_id = _meta_file.split("_")[1]
|
|
||||||
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
|
|
||||||
with open(meta_file, "r") as ttf:
|
with open(meta_file, "r") as ttf:
|
||||||
for line in ttf:
|
for line in ttf:
|
||||||
cols = line.split("\t")
|
cols = line.split("\t")
|
||||||
wav_file = os.path.join(_root_path, cols[0] + ".wav")
|
file_name = cols[0]
|
||||||
text = cols[1]
|
speaker_name, chapter_id, *_ = cols[0].split("_")
|
||||||
|
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
|
||||||
|
wav_file = os.path.join(_root_path, file_name + ".wav")
|
||||||
|
text = cols[2]
|
||||||
items.append([text, wav_file, "LTTS_" + speaker_name])
|
items.append([text, wav_file, "LTTS_" + speaker_name])
|
||||||
for item in items:
|
for item in items:
|
||||||
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
|
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
|
||||||
|
@ -287,6 +291,17 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
|
||||||
|
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
def mls(root_path, meta_files=None):
|
||||||
|
"""http://www.openslr.org/94/"""
|
||||||
|
items = []
|
||||||
|
with open(os.path.join(root_path, meta_files), "r") as meta:
|
||||||
|
for line in meta:
|
||||||
|
file, text = line.split('\t')
|
||||||
|
text = text[:-1]
|
||||||
|
speaker, book, *_ = file.split('_')
|
||||||
|
wav_file = os.path.join(root_path, os.path.dirname(meta_files), 'audio', speaker, book, file + ".wav")
|
||||||
|
items.append([text, wav_file, "MLS_" + speaker])
|
||||||
|
return items
|
||||||
|
|
||||||
# ======================================== VOX CELEB ===========================================
|
# ======================================== VOX CELEB ===========================================
|
||||||
def voxceleb2(root_path, meta_file=None):
|
def voxceleb2(root_path, meta_file=None):
|
||||||
|
|
|
@ -59,6 +59,7 @@ class SpeakerManager:
|
||||||
speaker_id_file_path: str = "",
|
speaker_id_file_path: str = "",
|
||||||
encoder_model_path: str = "",
|
encoder_model_path: str = "",
|
||||||
encoder_config_path: str = "",
|
encoder_config_path: str = "",
|
||||||
|
use_cuda: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.data_items = []
|
self.data_items = []
|
||||||
|
@ -67,6 +68,7 @@ class SpeakerManager:
|
||||||
self.clip_ids = []
|
self.clip_ids = []
|
||||||
self.speaker_encoder = None
|
self.speaker_encoder = None
|
||||||
self.speaker_encoder_ap = None
|
self.speaker_encoder_ap = None
|
||||||
|
self.use_cuda = use_cuda
|
||||||
|
|
||||||
if data_items:
|
if data_items:
|
||||||
self.speaker_ids, self.speaker_names, _ = self.parse_speakers_from_data(self.data_items)
|
self.speaker_ids, self.speaker_names, _ = self.parse_speakers_from_data(self.data_items)
|
||||||
|
@ -222,11 +224,11 @@ class SpeakerManager:
|
||||||
"""
|
"""
|
||||||
self.speaker_encoder_config = load_config(config_path)
|
self.speaker_encoder_config = load_config(config_path)
|
||||||
self.speaker_encoder = setup_model(self.speaker_encoder_config)
|
self.speaker_encoder = setup_model(self.speaker_encoder_config)
|
||||||
self.speaker_encoder.load_checkpoint(config_path, model_path, True)
|
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
|
||||||
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
|
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
|
||||||
# normalize the input audio level and trim silences
|
# normalize the input audio level and trim silences
|
||||||
self.speaker_encoder_ap.do_sound_norm = True
|
# self.speaker_encoder_ap.do_sound_norm = True
|
||||||
self.speaker_encoder_ap.do_trim_silence = True
|
# self.speaker_encoder_ap.do_trim_silence = True
|
||||||
|
|
||||||
def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
|
def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
|
||||||
"""Compute a d_vector from a given audio file.
|
"""Compute a d_vector from a given audio file.
|
||||||
|
@ -242,6 +244,8 @@ class SpeakerManager:
|
||||||
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
|
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
|
||||||
spec = self.speaker_encoder_ap.melspectrogram(waveform)
|
spec = self.speaker_encoder_ap.melspectrogram(waveform)
|
||||||
spec = torch.from_numpy(spec.T)
|
spec = torch.from_numpy(spec.T)
|
||||||
|
if self.use_cuda:
|
||||||
|
spec = spec.cuda()
|
||||||
spec = spec.unsqueeze(0)
|
spec = spec.unsqueeze(0)
|
||||||
d_vector = self.speaker_encoder.compute_embedding(spec)
|
d_vector = self.speaker_encoder.compute_embedding(spec)
|
||||||
return d_vector
|
return d_vector
|
||||||
|
@ -272,6 +276,8 @@ class SpeakerManager:
|
||||||
feats = torch.from_numpy(feats)
|
feats = torch.from_numpy(feats)
|
||||||
if feats.ndim == 2:
|
if feats.ndim == 2:
|
||||||
feats = feats.unsqueeze(0)
|
feats = feats.unsqueeze(0)
|
||||||
|
if self.use_cuda:
|
||||||
|
feats = feats.cuda()
|
||||||
return self.speaker_encoder.compute_embedding(feats)
|
return self.speaker_encoder.compute_embedding(feats)
|
||||||
|
|
||||||
def run_umap(self):
|
def run_umap(self):
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -35,7 +35,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
|
||||||
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
||||||
# compute d for a given batch
|
# compute d for a given batch
|
||||||
dummy_input = T.rand(1, 240, 80) # B x T x D
|
dummy_input = T.rand(1, 240, 80) # B x T x D
|
||||||
output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5)
|
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
|
||||||
assert output.shape[0] == 1
|
assert output.shape[0] == 1
|
||||||
assert output.shape[1] == 256
|
assert output.shape[1] == 256
|
||||||
assert len(output.shape) == 2
|
assert len(output.shape) == 2
|
||||||
|
|
Loading…
Reference in New Issue