Merge pull request #3226 from coqui-ai/dev

v0.20.5
This commit is contained in:
Eren Gölge 2023-11-15 15:34:29 +01:00 committed by GitHub
commit 6471273c1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 57 additions and 52 deletions

View File

@ -28,7 +28,7 @@
📚 Utilities for dataset analysis and curation.
______________________________________________________________________
[![Dicord](https://img.shields.io/discord/1037326658807533628?color=%239B59B6&label=chat%20on%20discord)](https://discord.gg/5eXr5seRrv)
[![Discord](https://img.shields.io/discord/1037326658807533628?color=%239B59B6&label=chat%20on%20discord)](https://discord.gg/5eXr5seRrv)
[![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
[![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)
[![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)

View File

@ -1 +1 @@
0.20.4
0.20.5

View File

@ -426,15 +426,6 @@ class GPT(nn.Module):
if max_mel_len > audio_codes.shape[-1]:
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
silence = True
for idx, l in enumerate(code_lengths):
length = l.item()
while silence:
if audio_codes[idx, length - 1] != 83:
break
length -= 1
code_lengths[idx] = length
# 💖 Lovely assertions
assert (
max_mel_len <= audio_codes.shape[-1]
@ -450,7 +441,7 @@ class GPT(nn.Module):
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
# Pad mel codes with stop_audio_token
audio_codes = self.set_mel_padding(audio_codes, code_lengths)
audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
# Build input and target tensors
# Prepend start token to inputs and append stop token to targets

View File

@ -115,7 +115,7 @@ _abbreviations = {
# There are not many common abbreviations in Arabic as in English.
]
],
"zh": [
"zh-cn": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@ -280,7 +280,7 @@ _symbols_multilingual = {
("°", " درجة "),
]
],
"zh": [
"zh-cn": [
# Chinese
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
@ -571,7 +571,7 @@ class VoiceBpeTokenizer:
)
def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
txt = multilingual_cleaners(txt, lang)
if lang in {"zh", "zh-cn"}:
txt = chinese_transliterate(txt)
@ -682,8 +682,8 @@ def test_expand_numbers_multilingual():
("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
# Chinese (Simplified)
("在12.5秒内", "在十二点五秒内", "zh"),
("有50名士兵", "有五十名士兵", "zh"),
("在12.5秒内", "在十二点五秒内", "zh-cn"),
("有50名士兵", "有五十名士兵", "zh-cn"),
# ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
# ("那将是20€先生", '那将是二十欧元先生', 'zh'),
# Turkish
@ -764,7 +764,7 @@ def test_symbols_multilingual():
("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
("我的电量为 14%", "我的电量为 14 百分之", "zh"),
("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),

View File

@ -7,7 +7,6 @@ import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
@ -308,26 +307,6 @@ class Xtts(BaseTTS):
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)
@torch.inference_mode()
def get_diffusion_cond_latents(self, audio, sr):
from math import ceil
diffusion_conds = []
CHUNK_SIZE = 102400
audio_24k = torchaudio.functional.resample(audio, sr, 24000)
for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
cond_mel = wav_to_univnet_mel(
current_sample.to(self.device),
do_normalization=False,
device=self.device,
)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
return diffusion_latent
@torch.inference_mode()
def get_speaker_embedding(self, audio, sr):
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
@ -530,8 +509,10 @@ class Xtts(BaseTTS):
top_p=0.85,
do_sample=True,
num_beams=1,
speed=1.0,
**hf_generate_kwargs,
):
length_scale = 1.0 / max(speed, 0.05)
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@ -573,16 +554,13 @@ class Xtts(BaseTTS):
return_attentions=False,
return_latent=True,
)
silence_token = 83
ctokens = 0
for k in range(gpt_codes.shape[-1]):
if gpt_codes[0, k] == silence_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
).transpose(1, 2)
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
@ -634,8 +612,10 @@ class Xtts(BaseTTS):
top_k=50,
top_p=0.85,
do_sample=True,
speed=1.0,
**hf_generate_kwargs,
):
length_scale = 1.0 / max(speed, 0.05)
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@ -674,6 +654,12 @@ class Xtts(BaseTTS):
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len

View File

@ -111,7 +111,7 @@ def test_xtts_streaming():
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
print("Inference...")
chunks = model.inference_stream(
@ -139,7 +139,7 @@ def test_xtts_v2():
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
)
else:
run_cli(
@ -164,7 +164,7 @@ def test_xtts_v2_streaming():
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
print("Inference...")
chunks = model.inference_stream(
@ -179,6 +179,34 @@ def test_xtts_v2_streaming():
assert chunk.shape[-1] > 5000
wav_chuncks.append(chunk)
assert len(wav_chuncks) > 1
normal_len = sum([len(chunk) for chunk in wav_chuncks])
chunks = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding,
speed=1.5
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
wav_chuncks.append(chunk)
fast_len = sum([len(chunk) for chunk in wav_chuncks])
chunks = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding,
speed=0.66
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
wav_chuncks.append(chunk)
slow_len = sum([len(chunk) for chunk in wav_chuncks])
assert slow_len > normal_len
assert normal_len > fast_len
def test_tortoise():