mirror of https://github.com/coqui-ai/TTS.git
UPDATE add audio edit function
parent 5dcc16d193
commit 3559581dc9
@@ -691,6 +691,180 @@ class Xtts(BaseTTS):
                    last_tokens = []
                    yield wav_chunk

    def init_for_audio_edit(self, dvae_checkpoint, mel_norm_file):
        from TTS.tts.layers.xtts.dvae import DiscreteVAE
        from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram

        dvae_sample_rate: int = 22050
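        # NOTE: these hyperparameters must match the DVAE checkpoint being loaded;
        # they mirror the values used by the XTTS training code in this repo.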
        # Load DVAE
        self.dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=self.args.gpt_num_audio_tokens - 2,
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )

        self.dvae.eval()
        dvae_dict = torch.load(dvae_checkpoint, map_location=torch.device("cpu"))
        self.dvae.load_state_dict(dvae_dict, strict=False)
        print(">> DVAE weights restored from:", dvae_checkpoint)

        # Mel spectrogram extractor for DVAE
        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
            mel_norm_file=mel_norm_file, sampling_rate=dvae_sample_rate
        )

        self.dvae = self.dvae.to(self.device)
        self.torch_mel_spectrogram_dvae = self.torch_mel_spectrogram_dvae.to(self.device)
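
    # `audio_edit` regenerates `text_edit` between two existing recordings: it
    # samples candidate audio for (text_left + text_edit), scores each candidate
    # against the untouched right-context codes, and vocodes left context + best edit.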
    @torch.inference_mode()
    def audio_edit(
        self,
        text_left,
        text_edit,
        text_right,
        audio_left_path,
        audio_right_path,
        language,
        gpt_cond_latent,
        speaker_embedding,
        generate_times=5,
        # GPT inference
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=50,
        top_p=0.85,
        do_sample=True,
        num_beams=1,
        speed=1.0,
        enable_text_splitting=False,
        **hf_generate_kwargs,
    ):
        load_sr: int = 22050
        audio_left = load_audio(audio_left_path, load_sr)
        audio_left = audio_left.to(self.device)
        audio_right = load_audio(audio_right_path, load_sr)
        audio_right = audio_right.to(self.device)
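
        # Mel-encode both flanking recordings into discrete DVAE codes; these serve
        # as fixed acoustic context on either side of the span being regenerated.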
        dvae_mel_spec_left = self.torch_mel_spectrogram_dvae(audio_left)
        codes_left = self.dvae.get_codebook_indices(dvae_mel_spec_left)
        dvae_mel_spec_right = self.torch_mel_spectrogram_dvae(audio_right)
        codes_right = self.dvae.get_codebook_indices(dvae_mel_spec_right)

        language = language.split("-")[0]  # remove the country code
        length_scale = 1.0 / max(speed, 0.05)
        gpt_cond_latent = gpt_cond_latent.to(self.device)
        speaker_embedding = speaker_embedding.to(self.device)

        wavs = []
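        # Tokenize three views of the text: left+edit (drives generation), left alone,
        # and edit+right (used to score candidates against the known right context).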
        with torch.no_grad():
            sent = text_left + text_edit
            sent = sent.strip().lower()
            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

            sent_left = text_left.strip().lower()
            text_left_tokens = torch.IntTensor(self.tokenizer.encode(sent_left, lang=language)).unsqueeze(0).to(self.device)

            sent_right = text_edit + text_right
            sent_right = sent_right.strip().lower()
            text_right_tokens = torch.IntTensor(self.tokenizer.encode(sent_right, lang=language)).unsqueeze(0).to(self.device)

            assert (
                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
            gpt_inputs = self.gpt.compute_embeddings(gpt_cond_latent, text_tokens)
            # add codes_left to inputs
            gpt_inputs = torch.cat([gpt_inputs, codes_left], dim=1)

            best_gpt_latents = None
            best_logits = 0
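            # Draw `generate_times` candidates and keep the one that best predicts the
            # known right-context codes (scored below via summed softmax mass).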
            for _ in range(generate_times):
                gen = self.gpt.gpt_inference.generate(
                    gpt_inputs,
                    bos_token_id=self.gpt.start_audio_token,
                    pad_token_id=self.gpt.stop_audio_token,
                    eos_token_id=self.gpt.stop_audio_token,
                    max_length=self.gpt.max_gen_mel_tokens + gpt_inputs.shape[-1],
                    do_sample=do_sample,
                    top_p=top_p,
                    top_k=top_k,
                    temperature=temperature,
                    num_return_sequences=self.gpt_batch_size,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                    repetition_penalty=repetition_penalty,
                    output_attentions=False,
                    **hf_generate_kwargs,
                )
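                # Keep only the newly generated audio codes: drop the prompt prefix
                # and the trailing stop token.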
                gpt_codes = gen[:, gpt_inputs.shape[1]: -1]
                gen_len = gpt_codes.shape[-1]
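                # Append the ground-truth right codes and teacher-force the combined
                # sequence through the GPT to get latents for the candidate edit.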
                gpt_right_codes = torch.cat([gpt_codes, codes_right], dim=1)
                expected_output_len = torch.tensor(
                    [gpt_right_codes.shape[-1] * self.gpt.code_stride_len], device=text_right_tokens.device
                )
                text_len = torch.tensor([text_right_tokens.shape[-1]], device=self.device)
                gpt_latents = self.gpt(
                    text_right_tokens,
                    text_len,
                    gpt_right_codes,
                    expected_output_len,
                    cond_latents=gpt_cond_latent,
                    return_attentions=False,
                    return_latent=True,
                )
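                # Candidate score: softmax over the positions that follow the generated
                # span, gathered at the ground-truth right codes and summed to a scalar.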
                right_logits = torch.gather(
                    F.softmax(gpt_latents[:, gen_len:], dim=-1),
                    -1,
                    codes_right.unsqueeze(0).transpose(-1, -2),
                )
                sum_logits = torch.sum(right_logits, dim=-2)
                sum_logits = float(sum_logits[0, 0])
                if sum_logits > best_logits:
                    best_logits = sum_logits
                    best_gpt_latents = gpt_latents.detach()
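
            # Compute latents for the untouched left context so the vocoder can
            # synthesize left context and edited span in a single pass.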
            expected_left_len = torch.tensor(
                [codes_left.shape[-1] * self.gpt.code_stride_len], device=text_left_tokens.device
            )
            text_left_len = torch.tensor([text_left_tokens.shape[-1]], device=self.device)
            gpt_left_latents = self.gpt(
                text_left_tokens,
                text_left_len,
                codes_left,
                expected_left_len,
                cond_latents=gpt_cond_latent,
                return_attentions=False,
                return_latent=True,
            )

            # if length_scale != 1.0:
            #     best_gpt_latents = F.interpolate(
            #         best_gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
            #     ).transpose(1, 2)

            # wavs.append(load_audio(audio_left_path, self.args.output_sample_rate).cpu().squeeze())
            # wavs.append(self.hifigan_decoder(best_gpt_latents, g=speaker_embedding).cpu().squeeze())
            # wavs.append(load_audio(audio_right_path, self.args.output_sample_rate).cpu().squeeze())
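            # Vocode left-context latents concatenated with the best edit latents.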
            wavs.append(
                self.hifigan_decoder(
                    torch.cat([gpt_left_latents, best_gpt_latents], dim=1), g=speaker_embedding
                ).cpu().squeeze()
            )

        return {
            "wav": torch.cat(wavs, dim=0).numpy(),
        }

    def forward(self):
        raise NotImplementedError(
            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
        )
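
For reference, a minimal usage sketch (not part of the commit). It assumes a standard XTTS checkpoint directory containing config.json, dvae.pth, and mel_stats.pth, plus a speaker reference wav and the two flanking recordings; all paths and texts below are placeholders.

import torch
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Hypothetical checkpoint layout; adjust to your own.
CKPT_DIR = "xtts_checkpoint_dir"

config = XttsConfig()
config.load_json(f"{CKPT_DIR}/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=CKPT_DIR, eval=True)

# Wire up the DVAE that tokenizes the flanking recordings.
model.init_for_audio_edit(
    dvae_checkpoint=f"{CKPT_DIR}/dvae.pth",
    mel_norm_file=f"{CKPT_DIR}/mel_stats.pth",
)

# Speaker conditioning, as in regular XTTS inference.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["speaker_ref.wav"])

# Regenerate the middle span between two existing recordings.
out = model.audio_edit(
    text_left="The weather today is ",
    text_edit="absolutely wonderful,",
    text_right=" so let's go outside.",
    audio_left_path="left_context.wav",
    audio_right_path="right_context.wav",
    language="en",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
)
# XTTS's HiFi-GAN decoder outputs 24 kHz audio.
torchaudio.save("edited.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)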