mirror of https://github.com/coqui-ai/TTS.git
commit 8716bf855b

@@ -71,7 +71,7 @@ class TTS(nn.Module):
         self.voice_converter = None
         self.csapi = None
         self.cs_api_model = cs_api_model
-        self.model_name = None
+        self.model_name = ""
 
         if gpu:
             warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
@@ -460,7 +460,7 @@ class TTS(nn.Module):
         """
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
             # Lazy code... save it to a temp file to resample it while reading it for VC
-            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name)
+            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name, speaker_wav=speaker_wav)
         if self.voice_converter is None:
             self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
         wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
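
Note on the hunk above: `speaker_wav` is now forwarded to the intermediate `tts_to_file()` call, so models that need a reference speaker can complete the synthesis step before voice conversion runs. A minimal usage sketch of the affected path; the model name and file paths below are placeholders, not part of this commit:

    from TTS.api import TTS

    tts = TTS("tts_models/multilingual/multi-dataset/your_tts")  # placeholder model name
    tts.tts_with_vc_to_file(
        text="Hello world",
        language="en",
        speaker_wav="reference_speaker.wav",  # reference voice, now also passed to the TTS step
        file_path="output.wav",
    )
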
@@ -483,13 +483,10 @@ class VoiceBpeTokenizer:
         if lang == "zh-cn":
             txt = chinese_transliterate(txt)
         elif lang == "ja":
-            assert txt[:4] == "[ja]", "Japanese speech should start with the [ja] token."
-            txt = txt[4:]
             if self.katsu is None:
                 import cutlet
                 self.katsu = cutlet.Cutlet()
             txt = japanese_cleaners(txt, self.katsu)
-            txt = "[ja]" + txt
         else:
             raise NotImplementedError()
         return txt
@@ -497,6 +494,7 @@ class VoiceBpeTokenizer:
     def encode(self, txt, lang):
         if self.preprocess:
             txt = self.preprocess_text(txt, lang)
+        txt = f"[{lang}]{txt}"
         txt = txt.replace(" ", "[SPACE]")
         return self.tokenizer.encode(txt).ids
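
Note on the two tokenizer hunks above: `encode()` now prepends the `[{lang}]` tag itself, so `preprocess_text()` no longer has to assert, strip, and re-add `[ja]` for Japanese input. A rough sketch of the new behaviour; the constructor argument and sample text are placeholders:

    tokenizer = VoiceBpeTokenizer(vocab_file="vocab.json")  # placeholder vocab path
    # Callers now pass plain text; the language tag is added inside encode().
    # Previously Japanese input had to arrive already prefixed with "[ja]".
    ids = tokenizer.encode("konnichiwa", lang="ja")
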
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import torch
 import torch.nn.functional as F
 import torchaudio
+import librosa
 from coqpit import Coqpit
 
 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
@@ -21,34 +22,6 @@ from TTS.utils.io import load_fsspec
 init_stream_support()
 
 
-def load_audio(audiopath, sr=22050):
-    """
-    Load an audio file from disk and resample it to the specified sampling rate.
-
-    Args:
-        audiopath (str): Path to the audio file.
-        sr (int): Target sampling rate.
-
-    Returns:
-        Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples.
-    """
-    audio, sampling_rate = torchaudio.load(audiopath)
-
-    if len(audio.shape) > 1:
-        if audio.shape[0] < 5:
-            audio = audio[0]
-        else:
-            assert audio.shape[1] < 5
-            audio = audio[:, 0]
-
-    if sampling_rate != sr:
-        resampler = torchaudio.transforms.Resample(sampling_rate, sr)
-        audio = resampler(audio)
-
-    audio = audio.clamp_(-1, 1)
-    return audio.unsqueeze(0)
-
-
 def wav_to_mel_cloning(
     wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
 ):
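
Note on the removal above: `load_audio()` is no longer needed because the reference audio is now loaded once in `get_conditioning_latents()` (via `torchaudio.load`) and resampled per consumer with `torchaudio.functional.resample`. A standalone sketch of that pattern, not code from this commit; the path is a placeholder:

    import torchaudio

    wav, sr = torchaudio.load("speaker.wav")  # placeholder path
    if wav.shape[0] > 1:                      # downmix multi-channel audio to mono
        wav = wav.mean(0, keepdim=True)
    wav_22k = torchaudio.functional.resample(wav, sr, 22050)
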
@@ -376,7 +349,7 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device
 
     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
 
         Args:
@@ -384,24 +357,21 @@ class Xtts(BaseTTS):
             length (int): Length of the audio in seconds. Defaults to 3.
         """
 
-        audio = load_audio(audio_path)
-        audio = audio[:, : 22050 * length]
-        mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
+        audio_22k = torchaudio.functional.resample(audio, sr, 22050)
+        audio_22k = audio_22k[:, : 22050 * length]
+        mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
 
     @torch.inference_mode()
-    def get_diffusion_cond_latents(
-        self,
-        audio_path,
-    ):
+    def get_diffusion_cond_latents(self, audio, sr):
         from math import ceil
 
         diffusion_conds = []
         CHUNK_SIZE = 102400
-        audio = load_audio(audio_path, 24000)
-        for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
-            current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
+        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
+        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
+            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
             current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
             cond_mel = wav_to_univnet_mel(
                 current_sample.to(self.device),
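
Note on the hunk above: `get_gpt_cond_latents()` and `get_diffusion_cond_latents()` now take an in-memory waveform plus its sample rate instead of a file path, and each resamples to the rate it needs (22.05 kHz for the mel-cloning path, 24 kHz for the diffusion conditioner), so the reference file is decoded only once. A hedged call sketch; `model` stands for a loaded Xtts instance and the path is a placeholder:

    audio, sr = torchaudio.load("speaker.wav")
    gpt_latents = model.get_gpt_cond_latents(audio, sr, length=6)    # resampled to 22050 Hz internally
    diffusion_latents = model.get_diffusion_cond_latents(audio, sr)  # resampled to 24000 Hz internally
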
@@ -414,27 +384,38 @@ class Xtts(BaseTTS):
         return diffusion_latent
 
     @torch.inference_mode()
-    def get_speaker_embedding(self, audio_path):
-        audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
-        speaker_embedding = (
-            self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
-            .unsqueeze(-1)
-            .to(self.device)
-        )
-        return speaker_embedding
+    def get_speaker_embedding(self, audio, sr):
+        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
+        return self.hifigan_decoder.speaker_encoder.forward(
+            audio_16k.to(self.device), l2_norm=True
+        ).unsqueeze(-1).to(self.device)
 
+    @torch.inference_mode()
     def get_conditioning_latents(
         self,
         audio_path,
-        gpt_cond_len=3,
+        gpt_cond_len=6,
+        max_ref_length=10,
+        librosa_trim_db=None,
+        sound_norm_refs=False,
     ):
         speaker_embedding = None
         diffusion_cond_latents = None
-        if self.args.use_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio_path)
+
+        audio, sr = torchaudio.load(audio_path)
+        audio = audio[:, : sr * max_ref_length].to(self.device)
+        if audio.shape[0] > 1:
+            audio = audio.mean(0, keepdim=True)
+        if sound_norm_refs:
+            audio = (audio / torch.abs(audio).max()) * 0.75
+        if librosa_trim_db is not None:
+            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+
+        if self.args.use_hifigan or self.args.use_ne_hifigan:
+            speaker_embedding = self.get_speaker_embedding(audio, sr)
         else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
+        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
         return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
 
     def synthesize(self, text, config, speaker_wav, language, **kwargs):
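
Note on the hunk above: `get_conditioning_latents()` now owns all reference-audio preparation: it loads the file once, clips it to `max_ref_length` seconds, downmixes to mono, optionally peak-normalizes (`sound_norm_refs`) and trims silence (`librosa_trim_db`), then feeds the same tensor to the speaker encoder, the diffusion conditioner, and the GPT conditioner. A hedged usage sketch; `model` stands for a loaded Xtts instance and the path is a placeholder:

    gpt_latents, diffusion_latents, speaker_embedding = model.get_conditioning_latents(
        audio_path="speaker.wav",
        gpt_cond_len=6,          # new default (was 3)
        max_ref_length=10,       # new: cap the reference clip at 10 seconds
        librosa_trim_db=None,    # new: optional silence-trimming threshold in dB
        sound_norm_refs=False,   # new: optional peak normalization of the reference
    )
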
@@ -494,7 +475,7 @@ class Xtts(BaseTTS):
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
+        gpt_cond_len=6,
         do_sample=True,
         # Decoder inference
         decoder_iterations=100,
@@ -531,7 +512,7 @@ class Xtts(BaseTTS):
                 (aka boring) outputs. Defaults to 0.8.
 
             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 3 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
 
             decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
                 more chances to iteratively refine the output, which should theoretically mean a higher quality output.
@@ -610,7 +591,7 @@ class Xtts(BaseTTS):
         decoder="hifigan",
         **hf_generate_kwargs,
     ):
-        text = f"[{language}]{text.strip().lower()}"
+        text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
         assert (
@@ -722,7 +703,7 @@ class Xtts(BaseTTS):
         assert hasattr(
             self, "hifigan_decoder"
         ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
-        text = f"[{language}]{text.strip().lower()}"
+        text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
         fake_inputs = self.gpt.compute_embeddings(
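
Note on the last two hunks: `inference()` and `inference_stream()` stop prepending `[{language}]` to the text because the tokenizer's `encode()` (changed earlier in this commit) now adds the language tag itself; keeping the old prefix would have produced a doubled tag such as "[en][en]hello". Both paths now reduce to roughly:

    text = text.strip().lower()
    ids = self.tokenizer.encode(text, lang=language)  # encode() prepends the "[en]" tag when lang is "en"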