diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index af94675b..800ff612 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -405,19 +405,37 @@ class Xtts(BaseTTS):
         librosa_trim_db=None,
         sound_norm_refs=False,
     ):
+        # deal with multiple references
+        if not isinstance(audio_path, list):
+            audio_paths = [audio_path]
+        else:
+            audio_paths = audio_path
+
+        speaker_embeddings = []
+        audios = []
         speaker_embedding = None
+        for file_path in audio_paths:
+            audio, sr = torchaudio.load(file_path)
+            audio = audio[:, : sr * max_ref_length].to(self.device)
+            if audio.shape[0] > 1:
+                audio = audio.mean(0, keepdim=True)
+            if sound_norm_refs:
+                audio = (audio / torch.abs(audio).max()) * 0.75
+            if librosa_trim_db is not None:
+                audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
-        audio, sr = torchaudio.load(audio_path)
-        audio = audio[:, : sr * max_ref_length].to(self.device)
-        if audio.shape[0] > 1:
-            audio = audio.mean(0, keepdim=True)
-        if sound_norm_refs:
-            audio = (audio / torch.abs(audio).max()) * 0.75
-        if librosa_trim_db is not None:
-            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+            speaker_embedding = self.get_speaker_embedding(audio, sr)
+            speaker_embeddings.append(speaker_embedding)
+            audios.append(audio)
+
+        # use a merge of all references for gpt cond latents
+        full_audio = torch.cat(audios, dim=-1)
+        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, sr, length=gpt_cond_len)  # [1, 1024, T]
+
+        if speaker_embeddings:
+            speaker_embedding = torch.stack(speaker_embeddings)
+            speaker_embedding = speaker_embedding.mean(dim=0)
-        speaker_embedding = self.get_speaker_embedding(audio, sr)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
         return gpt_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
@@ -436,11 +454,6 @@ class Xtts(BaseTTS):
             as latents used at inference.

         """
-
-        # Make the synthesizer happy 🥳
-        if isinstance(speaker_wav, list):
-            speaker_wav = speaker_wav[0]
-
         return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)

     def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
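
For reviewers, here is a minimal usage sketch (not part of the patch) of how the changed entry points might be exercised after this change. The checkpoint/config paths and reference file names are hypothetical; the loading sequence follows the usual XTTS pattern, and `inference()` is assumed to accept the latents produced above.

# Usage sketch -- paths and .wav file names below are hypothetical.
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # hypothetical path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/")  # hypothetical path
model.cuda()

# A single reference still works; it is wrapped into a one-element list.
gpt_cond_latents, speaker_embedding = model.get_conditioning_latents(audio_path="ref.wav")

# Multiple references: per-file speaker embeddings are averaged, and the
# reference audios are concatenated before computing the GPT cond latents.
gpt_cond_latents, speaker_embedding = model.get_conditioning_latents(
    audio_path=["ref_1.wav", "ref_2.wav", "ref_3.wav"]
)

out = model.inference(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latents,
    speaker_embedding,
)

One thing worth noting about the concatenation path: `sr` from the last loaded reference is reused when calling `get_gpt_cond_latents` on `full_audio`, so the multi-reference branch implicitly assumes all reference files share the same sample rate.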