Code linting

This commit is contained in:
Eren G??lge 2023-11-06 11:23:38 +01:00
parent b094979f1a
commit 55801049de
2 changed files with 22 additions and 31 deletions

View File

@ -381,33 +381,26 @@ class Xtts(BaseTTS):
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
audio_22k = audio_22k[:, : 22050 * length]
if self.args.gpt_use_perceiver_resampler:
mel = wav_to_mel_cloning(
audio_22k,
mel_norms=self.mel_stats.cpu(),
n_fft=2048,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
n_fft = 2048
hop_length = 256
win_length = 1024
else:
mel = wav_to_mel_cloning(
audio_22k,
mel_norms=self.mel_stats.cpu(),
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
n_fft = 4096
hop_length = 1024
win_length = 4096
mel = wav_to_mel_cloning(
audio_22k,
mel_norms=self.mel_stats.cpu(),
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)

View File

@ -27,16 +27,16 @@ PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
# Set here the path that the checkpoints will be saved. Default: ./run/training/
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
os.makedirs(OUT_PATH, exist_ok=True)
# Create DVAE checkpoint and mel_norms on test time
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
# Mel spectrogram norms, required for dvae mel spectrogram extraction
MEL_NORM_FILE = os.path.join(
OUT_PATH, "mel_stats.pth"
) # Mel spectrogram norms, required for dvae mel spectrogram extraction
)
dvae = DiscreteVAE(
channels=80,
normalization=None,
@ -99,9 +99,7 @@ config = GPTTrainerConfig(
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
run_description="GPT XTTS training",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,