mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
00b24eeb6e
commit
1a9ca35e14
|
@ -562,21 +562,15 @@ class DPM_Solver:
|
|||
if order == 3:
|
||||
K = steps // 3 + 1
|
||||
if steps % 3 == 0:
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
orders = [3,] * (
|
||||
K - 2
|
||||
) + [2, 1]
|
||||
elif steps % 3 == 1:
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
orders = [3,] * (
|
||||
K - 1
|
||||
) + [1]
|
||||
else:
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
orders = [3,] * (
|
||||
K - 1
|
||||
) + [2]
|
||||
elif order == 2:
|
||||
|
@ -587,9 +581,7 @@ class DPM_Solver:
|
|||
] * K
|
||||
else:
|
||||
K = steps // 2 + 1
|
||||
orders = [
|
||||
2,
|
||||
] * (
|
||||
orders = [2,] * (
|
||||
K - 1
|
||||
) + [1]
|
||||
elif order == 1:
|
||||
|
@ -1448,10 +1440,7 @@ class DPM_Solver:
|
|||
model_prev_list[-1] = self.model_fn(x, t)
|
||||
elif method in ["singlestep", "singlestep_fixed"]:
|
||||
if method == "singlestep":
|
||||
(
|
||||
timesteps_outer,
|
||||
orders,
|
||||
) = self.get_orders_and_timesteps_for_singlestep_solver(
|
||||
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
|
||||
steps=steps,
|
||||
order=order,
|
||||
skip_type=skip_type,
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from collections import namedtuple
|
||||
from functools import wraps
|
||||
from packaging import version
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from packaging import version
|
||||
from torch import einsum, nn
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
|
|
@ -4,13 +4,12 @@ import re
|
|||
|
||||
import pypinyin
|
||||
import torch
|
||||
from hangul_romanize import Transliter
|
||||
from hangul_romanize.rule import academic
|
||||
from num2words import num2words
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||
from hangul_romanize import Transliter
|
||||
from hangul_romanize.rule import academic
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
|
|
|
@ -252,12 +252,7 @@ class BaseTacotron(BaseTTS):
|
|||
|
||||
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
|
||||
"""Capacitron Variational Autoencoder"""
|
||||
(
|
||||
VAE_outputs,
|
||||
posterior_distribution,
|
||||
prior_distribution,
|
||||
capacitron_beta,
|
||||
) = self.capacitron_vae_layer(
|
||||
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
|
||||
reference_mel_info,
|
||||
text_info,
|
||||
speaker_embedding, # pylint: disable=not-callable
|
||||
|
|
|
@ -676,12 +676,7 @@ class Tortoise(BaseTTS):
|
|||
), "Too much text provided. Break the text up into separate segments and re-try inference."
|
||||
|
||||
if voice_samples is not None:
|
||||
(
|
||||
auto_conditioning,
|
||||
diffusion_conditioning,
|
||||
_,
|
||||
_,
|
||||
) = self.get_conditioning_latents(
|
||||
(auto_conditioning, diffusion_conditioning, _, _,) = self.get_conditioning_latents(
|
||||
voice_samples,
|
||||
return_mels=True,
|
||||
latent_averaging_mode=latent_averaging_mode,
|
||||
|
|
|
@ -7,7 +7,6 @@ from TTS.tts.datasets import load_tts_samples
|
|||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTS_v2.0_LJSpeech_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
|
|
|
@ -34,9 +34,7 @@ os.makedirs(OUT_PATH, exist_ok=True)
|
|||
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
||||
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
||||
# Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||
MEL_NORM_FILE = os.path.join(
|
||||
OUT_PATH, "mel_stats.pth"
|
||||
)
|
||||
MEL_NORM_FILE = os.path.join(OUT_PATH, "mel_stats.pth")
|
||||
dvae = DiscreteVAE(
|
||||
channels=80,
|
||||
normalization=None,
|
||||
|
|
Loading…
Reference in New Issue