mirror of https://github.com/coqui-ai/TTS.git
Add support for multiple speaker references on XTTS inference
parent 9942000c50
commit 459ad70dc8
@@ -405,19 +405,37 @@ class Xtts(BaseTTS):
         librosa_trim_db=None,
         sound_norm_refs=False,
     ):
+        # deal with multiple references
+        if not isinstance(audio_path, list):
+            audio_paths = [audio_path]
+        else:
+            audio_paths = audio_path
+
+        speaker_embeddings = []
+        audios = []
         speaker_embedding = None
-
-        audio, sr = torchaudio.load(audio_path)
-        audio = audio[:, : sr * max_ref_length].to(self.device)
-        if audio.shape[0] > 1:
-            audio = audio.mean(0, keepdim=True)
-        if sound_norm_refs:
-            audio = (audio / torch.abs(audio).max()) * 0.75
-        if librosa_trim_db is not None:
-            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
-
-        speaker_embedding = self.get_speaker_embedding(audio, sr)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
+        for file_path in audio_paths:
+            audio, sr = torchaudio.load(file_path)
+            audio = audio[:, : sr * max_ref_length].to(self.device)
+            if audio.shape[0] > 1:
+                audio = audio.mean(0, keepdim=True)
+            if sound_norm_refs:
+                audio = (audio / torch.abs(audio).max()) * 0.75
+            if librosa_trim_db is not None:
+                audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+
+            speaker_embedding = self.get_speaker_embedding(audio, sr)
+            speaker_embeddings.append(speaker_embedding)
+            audios.append(audio)
+
+        # use a merge of all references for gpt cond latents
+        full_audio = torch.cat(audios, dim=-1)
+        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, sr, length=gpt_cond_len)  # [1, 1024, T]
+
+        if speaker_embeddings:
+            speaker_embedding = torch.stack(speaker_embeddings)
+            speaker_embedding = speaker_embedding.mean(dim=0)
+
         return gpt_cond_latents, speaker_embedding
 
     def synthesize(self, text, config, speaker_wav, language, **kwargs):
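The merge strategy in this hunk is simple: each reference clip is preprocessed exactly as before, the clips are concatenated along time to form the GPT conditioning input, and the per-clip speaker embeddings are averaged. A minimal, self-contained sketch of that merge (the tensor shapes and the toy embedding function are assumptions for illustration, not the model's real encoder):

import torch

# Stand-ins for three preprocessed reference clips: [1, T] mono waveforms.
audios = [torch.randn(1, 22050 * n) for n in (3, 5, 2)]

# Hypothetical speaker encoder producing a [1, 512] embedding per clip;
# the real model calls self.get_speaker_embedding(audio, sr) instead.
def toy_embedding(audio: torch.Tensor) -> torch.Tensor:
    return audio.mean(dim=-1, keepdim=True).repeat(1, 512)

speaker_embeddings = [toy_embedding(a) for a in audios]

# GPT conditioning latents come from one merged waveform: concat along time.
full_audio = torch.cat(audios, dim=-1)  # [1, sum of clip lengths]

# The final speaker embedding is the mean over all references.
speaker_embedding = torch.stack(speaker_embeddings).mean(dim=0)  # [1, 512]

print(full_audio.shape, speaker_embedding.shape)

Averaging keeps the speaker embedding at a fixed size no matter how many references are supplied, while concatenation gives the GPT conditioning path more acoustic context to draw from.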
@@ -436,11 +454,6 @@ class Xtts(BaseTTS):
             as latents used at inference.
 
         """
-
-        # Make the synthesizer happy 🥳
-        if isinstance(speaker_wav, list):
-            speaker_wav = speaker_wav[0]
-
         return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
 
     def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
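The second hunk removes the old workaround that silently discarded all but the first reference when speaker_wav was a list; the list now flows through to get_conditioning_latents intact. A hedged usage sketch (the paths are placeholders; loading follows the repository's usual XttsConfig/load_checkpoint pattern):

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/")  # placeholder
model.cuda()

# Multiple reference wavs are now accepted; they are merged internally.
outputs = model.synthesize(
    "Hello world.",
    config,
    speaker_wav=["reference_1.wav", "reference_2.wav"],
    language="en",
)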