let speaker manager compute mean x_vector from multiple wav files

This commit is contained in:
Eren Gölge 2021-04-23 17:46:50 +02:00
parent 179722e3a7
commit f69195739e
1 changed file with 25 additions and 9 deletions

View File

@@ -9,6 +9,8 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config from TTS.utils.io import load_config
from typing import Union
def make_speakers_json_path(out_path): def make_speakers_json_path(out_path):
"""Returns conventional speakers.json location.""" """Returns conventional speakers.json location."""
@@ -228,7 +230,7 @@ class SpeakerManager:
def get_clips(self): def get_clips(self):
return sorted(self.x_vectors.keys()) return sorted(self.x_vectors.keys())
def init_speaker_encoder(self, model_path: str, config_path: str): def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
self.speaker_encoder_config = load_config(config_path) self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_model(self.speaker_encoder_config) self.speaker_encoder = setup_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, True) self.speaker_encoder.load_checkpoint(config_path, model_path, True)
@@ -238,7 +240,8 @@ class SpeakerManager:
self.speaker_encoder_ap.do_sound_norm = True self.speaker_encoder_ap.do_sound_norm = True
self.speaker_encoder_ap.do_trim_silence = True self.speaker_encoder_ap.do_trim_silence = True
def compute_x_vector_from_clip(self, wav_file): def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav( waveform = self.speaker_encoder_ap.load_wav(
wav_file, sr=self.speaker_encoder_ap.sample_rate) wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform) spec = self.speaker_encoder_ap.melspectrogram(waveform)
@@ -246,6 +249,19 @@ class SpeakerManager:
spec = spec.unsqueeze(0) spec = spec.unsqueeze(0)
x_vector = self.speaker_encoder.compute_embedding(spec) x_vector = self.speaker_encoder.compute_embedding(spec)
return x_vector return x_vector
if isinstance(wav_file, list):
# compute the mean x_vector
x_vectors = None
for wf in wav_file:
x_vector = _compute(wf)
if x_vectors is None:
x_vectors = x_vector
else:
x_vectors += x_vector
return (x_vectors / len(wav_file))[0].tolist()
else:
x_vector = _compute(wav_file)
return x_vector[0].tolist()
def compute_x_vector(self, feats): def compute_x_vector(self, feats):
if isinstance(feats, np.ndarray): if isinstance(feats, np.ndarray):