mirror of https://github.com/coqui-ai/TTS.git
let speaker manager compute mean x_vector from multiple wav files
This commit is contained in:
parent
179722e3a7
commit
f69195739e
|
@ -9,6 +9,8 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
|
|||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
def make_speakers_json_path(out_path):
|
||||
"""Returns conventional speakers.json location."""
|
||||
|
@ -228,7 +230,7 @@ class SpeakerManager:
|
|||
def get_clips(self):
    """Return the keys of ``self.x_vectors`` as a sorted list."""
    clip_names = list(self.x_vectors.keys())
    clip_names.sort()
    return clip_names
|
||||
|
||||
def init_speaker_encoder(self, model_path: str, config_path: str):
|
||||
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
|
||||
self.speaker_encoder_config = load_config(config_path)
|
||||
self.speaker_encoder = setup_model(self.speaker_encoder_config)
|
||||
self.speaker_encoder.load_checkpoint(config_path, model_path, True)
|
||||
|
@ -238,14 +240,28 @@ class SpeakerManager:
|
|||
self.speaker_encoder_ap.do_sound_norm = True
|
||||
self.speaker_encoder_ap.do_trim_silence = True
|
||||
|
||||
def compute_x_vector_from_clip(self, wav_file):
    """Compute a speaker embedding for a single wav file.

    The clip is loaded through ``self.speaker_encoder_ap`` at that
    processor's ``sample_rate``, turned into a mel spectrogram, transposed,
    given a leading dimension and handed to
    ``self.speaker_encoder.compute_embedding``.

    Args:
        wav_file: Path of the wav file, passed to ``load_wav``.

    Returns:
        Whatever ``compute_embedding`` returns for the prepared spectrogram.
    """
    audio = self.speaker_encoder_ap.load_wav(
        wav_file, sr=self.speaker_encoder_ap.sample_rate
    )
    mel = self.speaker_encoder_ap.melspectrogram(audio)
    # Transpose, then prepend a size-1 dimension before calling the encoder.
    batch = torch.from_numpy(mel.T).unsqueeze(0)
    return self.speaker_encoder.compute_embedding(batch)
|
||||
def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
    """Compute a speaker embedding from one wav file or the mean over several.

    Args:
        wav_file: Either a single wav path, or a list of wav paths whose
            embeddings are averaged element-wise.

    Returns:
        The first row of the (mean) embedding, converted with ``tolist()``.

    Raises:
        ValueError: If ``wav_file`` is an empty list, since there is nothing
            to average.
    """

    def _embed(path: str):
        # Load at the encoder audio processor's own sample rate, build the
        # mel spectrogram, transpose it and prepend a size-1 dimension
        # before calling the encoder.
        waveform = self.speaker_encoder_ap.load_wav(
            path, sr=self.speaker_encoder_ap.sample_rate
        )
        spec = self.speaker_encoder_ap.melspectrogram(waveform)
        spec = torch.from_numpy(spec.T).unsqueeze(0)
        return self.speaker_encoder.compute_embedding(spec)

    if not isinstance(wav_file, list):
        return _embed(wav_file)[0].tolist()

    if not wav_file:
        # Previously this fell through to ``None / 0`` and failed with an
        # unrelated-looking TypeError.
        raise ValueError("wav_file list is empty; cannot compute a mean x_vector.")

    embeddings = [_embed(path) for path in wav_file]
    # Accumulate out of place: ``+=`` would modify the first tensor returned
    # by the encoder, which this method does not own.
    total = embeddings[0]
    for embedding in embeddings[1:]:
        total = total + embedding
    return (total / len(embeddings))[0].tolist()
|
||||
|
||||
def compute_x_vector(self, feats):
|
||||
if isinstance(feats, np.ndarray):
|
||||
|
|
Loading…
Reference in New Issue