mirror of https://github.com/coqui-ai/TTS.git
Add features to get_conditioning_latents
parent c1133724a1
commit d4e08c8d6c
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import torch
 import torch.nn.functional as F
 import torchaudio
+import librosa
 from coqpit import Coqpit
 
 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
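The new librosa import backs the optional silence trimming added to get_conditioning_latents further down in this diff. As a reminder of what that call does, here is a minimal, hedged sketch of librosa.effects.trim on its own; the file path and the 30 dB threshold are illustrative, not taken from the commit:

```python
import librosa

# Illustrative only: trim leading/trailing silence from a mono waveform.
# librosa.effects.trim returns the trimmed signal and the (start, end) sample indices.
wav, sr = librosa.load("reference.wav", sr=22050)             # "reference.wav" is a placeholder path
trimmed, (start, end) = librosa.effects.trim(wav, top_db=30)  # 30 dB is an example threshold
print(f"kept samples {start}:{end} of {len(wav)}")
```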
@@ -34,12 +35,8 @@ def load_audio(audiopath, sr=22050):
     """
     audio, sampling_rate = torchaudio.load(audiopath)
 
-    if len(audio.shape) > 1:
-        if audio.shape[0] < 5:
-            audio = audio[0]
-        else:
-            assert audio.shape[1] < 5
-            audio = audio[:, 0]
+    if audio.shape[0] > 1:
+        audio = audio.mean(0, keepdim=True)
 
     if sampling_rate != sr:
         resampler = torchaudio.transforms.Resample(sampling_rate, sr)
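The channel handling in load_audio is simplified here: instead of picking a single channel (guarded by an assert on the channel count), any multi-channel recording is now averaged down to mono. A standalone sketch of the resulting loader behavior, assuming torchaudio is available; this is not the complete function from the repository, just the part this hunk touches:

```python
import torch
import torchaudio


def load_audio_sketch(audiopath: str, sr: int = 22050) -> torch.Tensor:
    """Illustrative re-statement of the updated load_audio logic."""
    audio, sampling_rate = torchaudio.load(audiopath)
    # Downmix multi-channel input to mono by averaging channels,
    # keeping the [1, T] shape expected by the rest of the pipeline.
    if audio.shape[0] > 1:
        audio = audio.mean(0, keepdim=True)
    # Resample only when the source rate differs from the target rate.
    if sampling_rate != sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, sr)
        audio = resampler(audio)
    return audio
```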
@@ -376,7 +373,7 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device
 
     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
 
         Args:
@@ -384,24 +381,21 @@ class Xtts(BaseTTS):
             length (int): Length of the audio in seconds. Defaults to 3.
         """
 
-        audio = load_audio(audio_path)
-        audio = audio[:, : 22050 * length]
-        mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
+        audio_22k = torchaudio.functional.resample(audio, sr, 22050)
+        audio_22k = audio_22k[:, : 22050 * length]
+        mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
 
     @torch.inference_mode()
-    def get_diffusion_cond_latents(
-        self,
-        audio_path,
-    ):
+    def get_diffusion_cond_latents(self, audio, sr):
         from math import ceil
 
         diffusion_conds = []
         CHUNK_SIZE = 102400
-        audio = load_audio(audio_path, 24000)
-        for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
-            current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
+        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
+        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
+            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
             current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
             cond_mel = wav_to_univnet_mel(
                 current_sample.to(self.device),
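Both latent extractors now accept an in-memory waveform plus its sample rate and resample internally (22.05 kHz for the GPT conditioning mel, 24 kHz for the diffusion conditioner) instead of re-reading the file from disk. A minimal sketch of the fixed-size chunking pattern used for the diffusion conditioning; pad_or_truncate_sketch is a stand-in for the repository's pad_or_truncate helper, and the function names here are illustrative:

```python
from math import ceil

import torch
import torchaudio


def pad_or_truncate_sketch(t: torch.Tensor, length: int) -> torch.Tensor:
    """Stand-in for the repo's pad_or_truncate: force the last dim to `length`."""
    if t.shape[-1] >= length:
        return t[..., :length]
    return torch.nn.functional.pad(t, (0, length - t.shape[-1]))


def iter_cond_chunks(audio: torch.Tensor, sr: int, chunk_size: int = 102400):
    """Yield 24 kHz chunks of a [1, T] reference waveform, each padded to a fixed size."""
    audio_24k = torchaudio.functional.resample(audio, sr, 24000)
    for i in range(ceil(audio_24k.shape[1] / chunk_size)):
        chunk = audio_24k[:, i * chunk_size : (i + 1) * chunk_size]
        yield pad_or_truncate_sketch(chunk, chunk_size)
```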
@@ -414,27 +408,38 @@ class Xtts(BaseTTS):
         return diffusion_latent
 
     @torch.inference_mode()
-    def get_speaker_embedding(self, audio_path):
-        audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
-        speaker_embedding = (
-            self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
-            .unsqueeze(-1)
-            .to(self.device)
-        )
-        return speaker_embedding
+    def get_speaker_embedding(self, audio, sr):
+        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
+        return self.hifigan_decoder.speaker_encoder.forward(
+            audio_16k.to(self.device), l2_norm=True
+        ).unsqueeze(-1).to(self.device)
 
     @torch.inference_mode()
     def get_conditioning_latents(
         self,
         audio_path,
-        gpt_cond_len=3,
+        gpt_cond_len=6,
+        max_ref_length=10,
+        librosa_trim_db=None,
+        sound_norm_refs=False,
     ):
         speaker_embedding = None
         diffusion_cond_latents = None
+
+        audio, sr = torchaudio.load(audio_path)
+        audio = audio[:, : sr * max_ref_length].to(self.device)
+        if audio.shape[0] > 1:
+            audio = audio.mean(0, keepdim=True)
+        if sound_norm_refs:
+            audio = (audio / torch.abs(audio).max()) * 0.75
+        if librosa_trim_db is not None:
+            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+
         if self.args.use_hifigan or self.args.use_ne_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio_path)
+            speaker_embedding = self.get_speaker_embedding(audio, sr)
         else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
+        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
         return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
 
     def synthesize(self, text, config, speaker_wav, language, **kwargs):
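With this change, get_conditioning_latents loads the reference once, keeps at most max_ref_length seconds of it, optionally peak-normalizes it and strips silence, then routes the same tensor to all three extractors. A hedged usage sketch, assuming `model` is an already-loaded Xtts instance and `ref.wav` is a placeholder path:

```python
# Assumes `model` is a loaded XTTS model exposing the updated API from this commit.
gpt_cond_latents, diffusion_cond_latents, speaker_embedding = model.get_conditioning_latents(
    audio_path="ref.wav",   # placeholder reference clip
    gpt_cond_len=6,         # use the first 6 s for the GPT conditioning mel
    max_ref_length=10,      # keep at most 10 s of the reference
    librosa_trim_db=30,     # example threshold; None (the default) skips trimming
    sound_norm_refs=True,   # peak-normalize the reference to 0.75
)
# gpt_cond_latents has shape [1, 1024, T]; speaker_embedding is set when a HiFi-GAN
# decoder is used, otherwise diffusion_cond_latents is returned instead.
```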
@@ -494,7 +499,7 @@ class Xtts(BaseTTS):
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
+        gpt_cond_len=6,
         do_sample=True,
         # Decoder inference
         decoder_iterations=100,
@@ -531,7 +536,7 @@ class Xtts(BaseTTS):
                 (aka boring) outputs. Defaults to 0.8.
 
             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 3 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
 
             decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
                 more chances to iteratively refine the output, which should theoretically mean a higher quality output.
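The raised default means the GPT conditioning mel is computed from at most 22050 * 6 = 132300 samples; when the reference is shorter, the slice simply returns the whole clip, which is the behavior the docstring describes. A small worked check with illustrative tensor lengths:

```python
import torch

sr = 22050
gpt_cond_len = 6                     # new default, in seconds
short_ref = torch.zeros(1, 4 * sr)   # a 4-second reference clip
long_ref = torch.zeros(1, 15 * sr)   # a 15-second reference clip

print(short_ref[:, : sr * gpt_cond_len].shape)  # torch.Size([1, 88200])  -> whole clip is used
print(long_ref[:, : sr * gpt_cond_len].shape)   # torch.Size([1, 132300]) -> first 6 s are used
```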