From f69195739e0e2bc0cd968984934c84b7e4f1d3f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 23 Apr 2021 17:46:50 +0200 Subject: [PATCH] let speaker manager compute mean x_vector from multiple wav files --- TTS/tts/utils/speakers.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 45a55166..7bf91bc4 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -9,6 +9,8 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config +from typing import Union + def make_speakers_json_path(out_path): """Returns conventional speakers.json location.""" @@ -228,7 +230,7 @@ class SpeakerManager: def get_clips(self): return sorted(self.x_vectors.keys()) - def init_speaker_encoder(self, model_path: str, config_path: str): + def init_speaker_encoder(self, model_path: str, config_path: str) -> None: self.speaker_encoder_config = load_config(config_path) self.speaker_encoder = setup_model(self.speaker_encoder_config) self.speaker_encoder.load_checkpoint(config_path, model_path, True) @@ -238,14 +240,28 @@ class SpeakerManager: self.speaker_encoder_ap.do_sound_norm = True self.speaker_encoder_ap.do_trim_silence = True - def compute_x_vector_from_clip(self, wav_file): - waveform = self.speaker_encoder_ap.load_wav( - wav_file, sr=self.speaker_encoder_ap.sample_rate) - spec = self.speaker_encoder_ap.melspectrogram(waveform) - spec = torch.from_numpy(spec.T) - spec = spec.unsqueeze(0) - x_vector = self.speaker_encoder.compute_embedding(spec) - return x_vector + def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list: + def _compute(wav_file: str): + waveform = self.speaker_encoder_ap.load_wav( + wav_file, sr=self.speaker_encoder_ap.sample_rate) + spec = self.speaker_encoder_ap.melspectrogram(waveform) + spec = torch.from_numpy(spec.T) + spec = spec.unsqueeze(0) + x_vector = self.speaker_encoder.compute_embedding(spec) + return x_vector + if isinstance(wav_file, list): + # compute the mean x_vector + x_vectors = None + for wf in wav_file: + x_vector = _compute(wf) + if x_vectors is None: + x_vectors = x_vector + else: + x_vectors += x_vector + return (x_vectors / len(wav_file))[0].tolist() + else: + x_vector = _compute(wav_file) + return x_vector[0].tolist() def compute_x_vector(self, feats): if isinstance(feats, np.ndarray):