mirror of https://github.com/coqui-ai/TTS.git
Code linting
This commit is contained in:
parent
b094979f1a
commit
55801049de
|
@ -381,33 +381,26 @@ class Xtts(BaseTTS):
|
||||||
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
|
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
|
||||||
audio_22k = audio_22k[:, : 22050 * length]
|
audio_22k = audio_22k[:, : 22050 * length]
|
||||||
if self.args.gpt_use_perceiver_resampler:
|
if self.args.gpt_use_perceiver_resampler:
|
||||||
mel = wav_to_mel_cloning(
|
n_fft = 2048
|
||||||
audio_22k,
|
hop_length = 256
|
||||||
mel_norms=self.mel_stats.cpu(),
|
win_length = 1024
|
||||||
n_fft=2048,
|
|
||||||
hop_length=256,
|
|
||||||
win_length=1024,
|
|
||||||
power=2,
|
|
||||||
normalized=False,
|
|
||||||
sample_rate=22050,
|
|
||||||
f_min=0,
|
|
||||||
f_max=8000,
|
|
||||||
n_mels=80,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
mel = wav_to_mel_cloning(
|
n_fft = 4096
|
||||||
audio_22k,
|
hop_length = 1024
|
||||||
mel_norms=self.mel_stats.cpu(),
|
win_length = 4096
|
||||||
n_fft=4096,
|
mel = wav_to_mel_cloning(
|
||||||
hop_length=1024,
|
audio_22k,
|
||||||
win_length=4096,
|
mel_norms=self.mel_stats.cpu(),
|
||||||
power=2,
|
n_fft=n_fft,
|
||||||
normalized=False,
|
hop_length=hop_length,
|
||||||
sample_rate=22050,
|
win_length=win_length,
|
||||||
f_min=0,
|
power=2,
|
||||||
f_max=8000,
|
normalized=False,
|
||||||
n_mels=80,
|
sample_rate=22050,
|
||||||
)
|
f_min=0,
|
||||||
|
f_max=8000,
|
||||||
|
n_mels=80,
|
||||||
|
)
|
||||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
||||||
return cond_latent.transpose(1, 2)
|
return cond_latent.transpose(1, 2)
|
||||||
|
|
||||||
|
|
|
@ -27,16 +27,16 @@ PROJECT_NAME = "XTTS_trainer"
|
||||||
DASHBOARD_LOGGER = "tensorboard"
|
DASHBOARD_LOGGER = "tensorboard"
|
||||||
LOGGER_URI = None
|
LOGGER_URI = None
|
||||||
|
|
||||||
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
|
||||||
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
|
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
|
||||||
os.makedirs(OUT_PATH, exist_ok=True)
|
os.makedirs(OUT_PATH, exist_ok=True)
|
||||||
|
|
||||||
# Create DVAE checkpoint and mel_norms on test time
|
# Create DVAE checkpoint and mel_norms on test time
|
||||||
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
||||||
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
||||||
|
# Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||||
MEL_NORM_FILE = os.path.join(
|
MEL_NORM_FILE = os.path.join(
|
||||||
OUT_PATH, "mel_stats.pth"
|
OUT_PATH, "mel_stats.pth"
|
||||||
) # Mel spectrogram norms, required for dvae mel spectrogram extraction
|
)
|
||||||
dvae = DiscreteVAE(
|
dvae = DiscreteVAE(
|
||||||
channels=80,
|
channels=80,
|
||||||
normalization=None,
|
normalization=None,
|
||||||
|
@ -99,9 +99,7 @@ config = GPTTrainerConfig(
|
||||||
model_args=model_args,
|
model_args=model_args,
|
||||||
run_name=RUN_NAME,
|
run_name=RUN_NAME,
|
||||||
project_name=PROJECT_NAME,
|
project_name=PROJECT_NAME,
|
||||||
run_description="""
|
run_description="GPT XTTS training",
|
||||||
GPT XTTS training
|
|
||||||
""",
|
|
||||||
dashboard_logger=DASHBOARD_LOGGER,
|
dashboard_logger=DASHBOARD_LOGGER,
|
||||||
logger_uri=LOGGER_URI,
|
logger_uri=LOGGER_URI,
|
||||||
audio=audio_config,
|
audio=audio_config,
|
||||||
|
|
Loading…
Reference in New Issue