Mirror of https://github.com/coqui-ai/TTS.git
commit 6471273c1d
@@ -28,7 +28,7 @@
 📚 Utilities for dataset analysis and curation.
 ______________________________________________________________________

 [](https://discord.gg/5eXr5seRrv)
 [](https://opensource.org/licenses/MPL-2.0)
 [](https://badge.fury.io/py/TTS)
 [](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
@@ -1 +1 @@
-0.20.4
+0.20.5
@@ -426,15 +426,6 @@ class GPT(nn.Module):
         if max_mel_len > audio_codes.shape[-1]:
             audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))

-        silence = True
-        for idx, l in enumerate(code_lengths):
-            length = l.item()
-            while silence:
-                if audio_codes[idx, length - 1] != 83:
-                    break
-                length -= 1
-            code_lengths[idx] = length
-
         # 💖 Lovely assertions
         assert (
             max_mel_len <= audio_codes.shape[-1]
@@ -450,7 +441,7 @@ class GPT(nn.Module):
         audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)

         # Pad mel codes with stop_audio_token
-        audio_codes = self.set_mel_padding(audio_codes, code_lengths)
+        audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3)  # -3 to get the real code lengths without counting the start and stop tokens, which have not been added yet

         # Build input and target tensors
         # Prepend start token to inputs and append stop token to targets
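For context on the code_lengths - 3 adjustment above, here is a minimal sketch of what a padding helper like set_mel_padding does. It is not the repository's exact implementation, and the stop-token id used below is an assumption; the point is that the adjusted lengths decide where stop-token padding begins.

```python
import torch

def set_mel_padding(audio_codes, code_lengths, stop_audio_token=1025):
    # Hypothetical sketch: positions past each sequence's real length are
    # overwritten with the stop token so the GPT learns where the audio ends.
    for i, length in enumerate(code_lengths):
        audio_codes[i, int(length):] = stop_audio_token
    return audio_codes

codes = torch.zeros(2, 6, dtype=torch.long)
print(set_mel_padding(codes, torch.tensor([4, 2])))
```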
@@ -115,7 +115,7 @@ _abbreviations = {
             # There are not many common abbreviations in Arabic as in English.
         ]
     ],
-    "zh": [
+    "zh-cn": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
             # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@@ -280,7 +280,7 @@ _symbols_multilingual = {
             ("°", " درجة "),
         ]
     ],
-    "zh": [
+    "zh-cn": [
         # Chinese
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
@@ -571,7 +571,7 @@ class VoiceBpeTokenizer:
         )

     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn"}:
             txt = multilingual_cleaners(txt, lang)
             if lang in {"zh", "zh-cn"}:
                 txt = chinese_transliterate(txt)
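A hedged usage sketch of the renamed language key (the vocab path below is illustrative and not part of this commit): Chinese input is now requested with "zh-cn" and still goes through the pinyin transliteration step.

```python
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

# Illustrative path; any XTTS vocab.json can be used here.
tokenizer = VoiceBpeTokenizer(vocab_file="/path/to/xtts/vocab.json")
# Numbers are expanded first, then the text is transliterated to pinyin.
print(tokenizer.preprocess_text("在12.5秒内", "zh-cn"))
```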
@@ -682,8 +682,8 @@ def test_expand_numbers_multilingual():
         ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
         ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
         # Chinese (Simplified)
-        ("在12.5秒内", "在十二点五秒内", "zh"),
-        ("有50名士兵", "有五十名士兵", "zh"),
+        ("在12.5秒内", "在十二点五秒内", "zh-cn"),
+        ("有50名士兵", "有五十名士兵", "zh-cn"),
         # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
         # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
         # Turkish
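The same expectation as a standalone snippet, assuming expand_numbers_multilingual in TTS.tts.layers.xtts.tokenizer is the helper these test cases exercise (the test name suggests it, but the call itself is outside this hunk):

```python
from TTS.tts.layers.xtts.tokenizer import expand_numbers_multilingual

# Simplified-Chinese number expansion is now keyed by "zh-cn".
assert expand_numbers_multilingual("在12.5秒内", "zh-cn") == "在十二点五秒内"
assert expand_numbers_multilingual("有50名士兵", "zh-cn") == "有五十名士兵"
```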
@@ -764,7 +764,7 @@ def test_symbols_multilingual():
         ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
         ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
         ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
-        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
+        ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
         ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
         ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
         ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit

-from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
@@ -308,26 +307,6 @@ class Xtts(BaseTTS):
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

-    @torch.inference_mode()
-    def get_diffusion_cond_latents(self, audio, sr):
-        from math import ceil
-
-        diffusion_conds = []
-        CHUNK_SIZE = 102400
-        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
-        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
-            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
-            current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
-            cond_mel = wav_to_univnet_mel(
-                current_sample.to(self.device),
-                do_normalization=False,
-                device=self.device,
-            )
-            diffusion_conds.append(cond_mel)
-        diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
-        return diffusion_latent
-
     @torch.inference_mode()
     def get_speaker_embedding(self, audio, sr):
         audio_16k = torchaudio.functional.resample(audio, sr, 16000)
@@ -530,8 +509,10 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         num_beams=1,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@@ -573,16 +554,13 @@ class Xtts(BaseTTS):
             return_attentions=False,
             return_latent=True,
         )
-        silence_token = 83
-        ctokens = 0
-        for k in range(gpt_codes.shape[-1]):
-            if gpt_codes[0, k] == silence_token:
-                ctokens += 1
-            else:
-                ctokens = 0
-            if ctokens > 8:
-                gpt_latents = gpt_latents[:, :k]
-                break
+        if length_scale != 1.0:
+            gpt_latents = F.interpolate(
+                gpt_latents.transpose(1, 2),
+                scale_factor=length_scale,
+                mode="linear"
+            ).transpose(1, 2)

         wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)

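To illustrate the new speed control in isolation, a self-contained sketch follows; the latent shape is invented for the example, but the length_scale formula and the F.interpolate call mirror the diff above.

```python
import torch
import torch.nn.functional as F

# Illustrative latent sequence of shape (batch, time, channels); the real one comes from the GPT.
gpt_latents = torch.randn(1, 100, 1024)

speed = 1.5                            # > 1.0 means faster speech, < 1.0 slower
length_scale = 1.0 / max(speed, 0.05)  # clamped so speed close to zero cannot blow up

if length_scale != 1.0:
    gpt_latents = F.interpolate(
        gpt_latents.transpose(1, 2),   # interpolate along the time axis
        scale_factor=length_scale,
        mode="linear",
    ).transpose(1, 2)

print(gpt_latents.shape)  # roughly (1, 66, 1024) for speed=1.5: fewer frames, shorter audio
```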
@@ -634,8 +612,10 @@ class Xtts(BaseTTS):
         top_k=50,
         top_p=0.85,
         do_sample=True,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@@ -674,6 +654,12 @@ class Xtts(BaseTTS):

             if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                 gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                if length_scale != 1.0:
+                    gpt_latents = F.interpolate(
+                        gpt_latents.transpose(1, 2),
+                        scale_factor=length_scale,
+                        mode="linear"
+                    ).transpose(1, 2)
                 wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                 wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                     wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
@@ -111,7 +111,7 @@ def test_xtts_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)

     print("Inference...")
     chunks = model.inference_stream(
@@ -139,7 +139,7 @@ def test_xtts_v2():
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
-            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
     else:
         run_cli(
@@ -164,7 +164,7 @@ def test_xtts_v2_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)

     print("Inference...")
     chunks = model.inference_stream(
@@ -179,6 +179,34 @@ def test_xtts_v2_streaming():
         assert chunk.shape[-1] > 5000
         wav_chuncks.append(chunk)
     assert len(wav_chuncks) > 1
+    normal_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=1.5
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    fast_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=0.66
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    slow_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    assert slow_len > normal_len
+    assert normal_len > fast_len


 def test_tortoise():
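A hedged end-to-end sketch of the new speed argument from user code; the paths are illustrative and the loading boilerplate mirrors the surrounding tests rather than anything changed by this commit.

```python
import torch
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Illustrative paths; point these at a downloaded XTTS v2 model directory.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")

# speed > 1.0 compresses the GPT latents (faster speech); speed < 1.0 stretches them.
wav_chunks = []
for chunk in model.inference_stream(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.5,
):
    wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)
```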