mirror of https://github.com/coqui-ai/TTS.git
Load reference audio at 24 kHz to avoid issues with references that have different sample rates
This commit is contained in:
parent
00294ffdf6
commit
72b2bac0f8
|
@ -2,13 +2,10 @@ import os
|
|||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import torchaudio
|
||||
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
|
||||
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
|
||||
from TTS.tts.models.xtts import load_audio
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
|
|||
return rel_clip, rel_clip.shape[-1], cond_idxs
|
||||
|
||||
|
||||
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a single-channel waveform at ``sampling_rate``.

    Args:
        audiopath: Path of the audio file. Strings and path-like objects are
            both accepted; the extension check is case-insensitive.
        sampling_rate: Target sampling rate. The waveform is resampled when the
            rate reported by the loader differs from it.

    Returns:
        The loaded tensor, averaged over dim 0 when it has more than one
        channel, resampled if needed, and clipped in place to [-1, 1].
    """
    # Backend choice follows https://github.com/faroit/python_audio_loading_benchmark
    # Normalise the path before testing the extension so that pathlib.Path
    # inputs and upper-case extensions (e.g. "clip.MP3") also take the mp3 branch.
    is_mp3 = str(audiopath).lower().endswith(".mp3")
    if is_mp3:
        # mp3 files are read through the torchaudio sox backend
        audio, lsr = torchaudio_sox_load(audiopath)
    else:
        # every other format is read through the torchaudio soundfile backend
        audio, lsr = torchaudio_soundfile_load(audiopath)

    # stereo (or multi-channel) to mono if needed
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Sanity-check the value range and report suspicious files: any sample above 10
    # ('10' is arbitrary, since audio will often "overdrive" the [-1, 1] bounds),
    # or no negative sample at all.
    if torch.any(audio > 10) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    # clip invalid audio values (in place)
    audio.clip_(-1, 1)
    return audio
|
||||
|
||||
|
||||
class XTTSDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
|
||||
self.config = config
|
||||
|
|
|
@ -67,6 +67,31 @@ def wav_to_mel_cloning(
|
|||
return mel
|
||||
|
||||
|
||||
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a single-channel waveform at ``sampling_rate``.

    Args:
        audiopath: Path of the audio file. Strings and path-like objects are
            both accepted; the extension check is case-insensitive.
        sampling_rate: Target sampling rate. The waveform is resampled when the
            rate reported by the loader differs from it.

    Returns:
        The loaded tensor, averaged over dim 0 when it has more than one
        channel, resampled if needed, and clipped in place to [-1, 1].
    """
    # Backend choice follows https://github.com/faroit/python_audio_loading_benchmark
    # Normalise the path before testing the extension so that pathlib.Path
    # inputs and upper-case extensions (e.g. "clip.MP3") also take the mp3 branch.
    is_mp3 = str(audiopath).lower().endswith(".mp3")
    if is_mp3:
        # mp3 files are read through the torchaudio sox backend
        audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
    else:
        # every other format is read through the torchaudio soundfile backend
        audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)

    # stereo (or multi-channel) to mono if needed
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Sanity-check the value range and report suspicious files: any sample above 10
    # ('10' is arbitrary, since audio will often "overdrive" the [-1, 1] bounds),
    # or no negative sample at all.
    if torch.any(audio > 10) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    # clip invalid audio values (in place)
    audio.clip_(-1, 1)
    return audio
|
||||
|
||||
|
||||
def pad_or_truncate(t, length):
|
||||
"""
|
||||
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
|
||||
|
@ -404,6 +429,7 @@ class Xtts(BaseTTS):
|
|||
max_ref_length=10,
|
||||
librosa_trim_db=None,
|
||||
sound_norm_refs=False,
|
||||
load_sr=24000,
|
||||
):
|
||||
# deal with multiples references
|
||||
if not isinstance(audio_path, list):
|
||||
|
@ -415,8 +441,9 @@ class Xtts(BaseTTS):
|
|||
audios = []
|
||||
speaker_embedding = None
|
||||
for file_path in audio_paths:
|
||||
audio, sr = torchaudio.load(file_path)
|
||||
audio = audio[:, : sr * max_ref_length].to(self.device)
|
||||
# load the audio in 24khz to avoid issues with multiple sr references
|
||||
audio = load_audio(file_path, load_sr)
|
||||
audio = audio[:, : load_sr * max_ref_length].to(self.device)
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(0, keepdim=True)
|
||||
if sound_norm_refs:
|
||||
|
@ -424,13 +451,14 @@ class Xtts(BaseTTS):
|
|||
if librosa_trim_db is not None:
|
||||
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
|
||||
|
||||
speaker_embedding = self.get_speaker_embedding(audio, sr)
|
||||
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
|
||||
speaker_embeddings.append(speaker_embedding)
|
||||
|
||||
audios.append(audio)
|
||||
|
||||
# use a merge of all references for gpt cond latents
|
||||
full_audio = torch.cat(audios, dim=-1)
|
||||
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, sr, length=gpt_cond_len) # [1, 1024, T]
|
||||
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]
|
||||
|
||||
if speaker_embeddings:
|
||||
speaker_embedding = torch.stack(speaker_embeddings)
|
||||
|
|
Loading…
Reference in New Issue