mirror of https://github.com/coqui-ai/TTS.git

Make style

parent 26efdf6ee7
commit 44880f09ed
@@ -17,7 +17,6 @@ from tqdm import tqdm

 from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

-
 try:
     from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral

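The hunk above touches the guarded k_diffusion import. This try/except pattern lets the module load even when the optional sampler dependency is missing; a minimal sketch of the full pattern (only the try side appears in the hunk, so the except branch and the sampler table here are assumptions):

# Optional dependency guard: degrade gracefully if k_diffusion is absent.
try:
    from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral

    # Hypothetical registry mapping sampler names to the imported functions.
    K_DIFFUSION_SAMPLERS = {"dpm++2m": sample_dpmpp_2m, "euler_a": sample_euler_ancestral}
except ImportError:
    K_DIFFUSION_SAMPLERS = None  # callers must check before requesting these samplers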
@@ -441,7 +441,9 @@ class GPT(nn.Module):
         audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)

         # Pad mel codes with stop_audio_token
-        audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
+        audio_codes = self.set_mel_padding(
+            audio_codes, code_lengths - 3
+        )  # -3 to get the real code lengths without consider start and stop tokens that was not added yet

         # Build input and target tensors
         # Prepend start token to inputs and append stop token to targets
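For readers skimming the diff: set_mel_padding receives code_lengths - 3 because, per the inline comment, the start and stop tokens have not been added to the sequences yet at this point. A minimal, self-contained sketch of the padding step, with the helper body and the token id assumed rather than taken from the repo:

import torch
import torch.nn.functional as F

STOP_AUDIO_TOKEN = 1025  # hypothetical stop-token id


def set_mel_padding(audio_codes, code_lengths):
    # Overwrite everything past each sequence's true length with the stop token.
    for i, length in enumerate(code_lengths.tolist()):
        audio_codes[i, length:] = STOP_AUDIO_TOKEN
    return audio_codes


audio_codes = torch.randint(0, 1024, (2, 6))  # batch of 2 code sequences
code_lengths = torch.tensor([4, 6])           # true lengths per sequence

# Append one stop token to every row, as F.pad does in the hunk above
audio_codes = F.pad(audio_codes, (0, 1), value=STOP_AUDIO_TOKEN)
audio_codes = set_mel_padding(audio_codes, code_lengths)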
@@ -1,23 +1,22 @@
 import os
 import re
-import torch
-import pypinyin
 import textwrap
-
 from functools import cached_property

+import pypinyin
+import torch
 from hangul_romanize import Transliter
 from hangul_romanize.rule import academic
 from num2words import num2words
+from spacy.lang.ar import Arabic
+from spacy.lang.en import English
+from spacy.lang.es import Spanish
+from spacy.lang.ja import Japanese
+from spacy.lang.zh import Chinese
 from tokenizers import Tokenizer

 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

-from spacy.lang.en import English
-from spacy.lang.zh import Chinese
-from spacy.lang.ja import Japanese
-from spacy.lang.ar import Arabic
-from spacy.lang.es import Spanish

 def get_spacy_lang(lang):
     if lang == "zh":
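The import shuffle above is what isort produces: standard-library imports first, then third-party packages in alphabetical order, then first-party (TTS.*) imports, with a blank line between groups. A hedged sketch of checking that ordering with isort's Python API:

# isort.code() sorts the import statements in a source string (isort >= 5).
import isort

messy = "import torch\nimport os\n"
print(isort.code(messy))
# expected: "import os\n\nimport torch\n"  (stdlib group first, then third-party)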
@@ -32,6 +31,7 @@ def get_spacy_lang(lang):
         # For most languages, Enlish does the job
         return English()

+
 def split_sentence(text, lang, text_split_length=250):
     """Preprocess the input text"""
     text_splits = []
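get_spacy_lang() returns a bare spaCy language pipeline that split_sentence() can use for sentence boundaries. A hedged usage sketch built on spaCy's rule-based "sentencizer" pipe (the repo's exact usage may differ):

from spacy.lang.en import English

nlp = English()                     # blank English pipeline, no trained model needed
nlp.add_pipe("sentencizer")         # rule-based sentence boundary detection
doc = nlp("First sentence. Second sentence.")
print([s.text for s in doc.sents])  # ['First sentence.', 'Second sentence.']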
@@ -67,6 +67,7 @@ def split_sentence(text, lang, text_split_length=250):

     return text_splits

+
 _whitespace_re = re.compile(r"\s+")

 # List of (regular expression, replacement) pairs for abbreviations:
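_whitespace_re is the usual TTS text-cleaner regex for collapsing runs of whitespace. A small self-contained example (the helper name is assumed; only the regex appears in the hunk):

import re

_whitespace_re = re.compile(r"\s+")


def collapse_whitespace(text):
    # Replace any run of spaces, tabs, or newlines with a single space.
    return _whitespace_re.sub(" ", text)


print(collapse_whitespace("hello   \n  world"))  # "hello world"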
@@ -563,9 +563,7 @@ class Xtts(BaseTTS):

                 if length_scale != 1.0:
                     gpt_latents = F.interpolate(
-                        gpt_latents.transpose(1, 2),
-                        scale_factor=length_scale,
-                        mode="linear"
+                        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
                     ).transpose(1, 2)

                 gpt_latents_list.append(gpt_latents.cpu())
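The reflowed F.interpolate call implements speed control: the GPT latents are treated as a 1-D signal over time and linearly resampled by length_scale, so values above 1.0 stretch the latents (slower speech) and values below 1.0 compress them (faster speech). A minimal sketch with assumed shapes:

import torch
import torch.nn.functional as F

gpt_latents = torch.randn(1, 50, 1024)  # (batch, time, channels); shapes assumed
length_scale = 1.25                     # > 1.0 stretches the time axis

stretched = F.interpolate(
    gpt_latents.transpose(1, 2),  # -> (batch, channels, time), as F.interpolate expects
    scale_factor=length_scale,
    mode="linear",
).transpose(1, 2)                 # back to (batch, time, channels)
print(stretched.shape)            # torch.Size([1, 62, 1024])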
@@ -675,9 +673,7 @@ class Xtts(BaseTTS):
             gpt_latents = torch.cat(all_latents, dim=0)[None, :]
             if length_scale != 1.0:
                 gpt_latents = F.interpolate(
-                    gpt_latents.transpose(1, 2),
-                    scale_factor=length_scale,
-                    mode="linear"
+                    gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
                 ).transpose(1, 2)
             wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
             wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
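In the streaming path above, handle_chunks stitches consecutive wav chunks together using the wav_overlap carried between calls. The repo's implementation is not shown in this diff; a generic crossfade toy illustrating the idea (all names here are assumed):

import torch


def crossfade(prev_tail, new_chunk):
    # Blend the overlap region linearly to avoid clicks at chunk boundaries.
    n = prev_tail.shape[-1]
    fade = torch.linspace(0.0, 1.0, n)
    out = new_chunk.clone()
    out[:n] = prev_tail * (1.0 - fade) + new_chunk[:n] * fade
    return out


prev_tail = torch.randn(128)   # tail of the previous chunk (overlap region)
new_chunk = torch.randn(1024)  # freshly decoded chunk starting with the overlap
print(crossfade(prev_tail, new_chunk).shape)  # torch.Size([1024])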
@@ -186,7 +186,7 @@ def test_xtts_v2_streaming():
         "en",
         gpt_cond_latent,
         speaker_embedding,
-        speed=1.5
+        speed=1.5,
     )
     wav_chuncks = []
     for i, chunk in enumerate(chunks):
@@ -198,7 +198,7 @@ def test_xtts_v2_streaming():
         "en",
         gpt_cond_latent,
         speaker_embedding,
-        speed=0.66
+        speed=0.66,
     )
     wav_chuncks = []
     for i, chunk in enumerate(chunks):
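The only change in the two test hunks is a trailing comma after the last keyword argument, black's "magic trailing comma". With it in place, appending another argument later touches a single line instead of two, keeping future diffs minimal. A toy illustration:

def synth(text, lang, speed=1.0):
    return f"{lang}:{text}@{speed}"


out = synth(
    "hello",
    "en",
    speed=0.66,  # trailing comma: the next added kwarg will not touch this line
)
print(out)  # en:hello@0.66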