From 55801049de4132ee80b7c35c3c7288bad3eff71b Mon Sep 17 00:00:00 2001 From: Eren Gölge Date: Mon, 6 Nov 2023 11:23:38 +0100 Subject: [PATCH] Code linting --- TTS/tts/models/xtts.py | 45 +++++++++----------- tests/xtts_tests/test_xtts_v2-0_gpt_train.py | 8 ++-- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 4877e86d..ecb31a9a 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -381,33 +381,26 @@ class Xtts(BaseTTS): audio_22k = torchaudio.functional.resample(audio, sr, 22050) audio_22k = audio_22k[:, : 22050 * length] if self.args.gpt_use_perceiver_resampler: - mel = wav_to_mel_cloning( - audio_22k, - mel_norms=self.mel_stats.cpu(), - n_fft=2048, - hop_length=256, - win_length=1024, - power=2, - normalized=False, - sample_rate=22050, - f_min=0, - f_max=8000, - n_mels=80, - ) + n_fft = 2048 + hop_length = 256 + win_length = 1024 else: - mel = wav_to_mel_cloning( - audio_22k, - mel_norms=self.mel_stats.cpu(), - n_fft=4096, - hop_length=1024, - win_length=4096, - power=2, - normalized=False, - sample_rate=22050, - f_min=0, - f_max=8000, - n_mels=80, - ) + n_fft = 4096 + hop_length = 1024 + win_length = 4096 + mel = wav_to_mel_cloning( + audio_22k, + mel_norms=self.mel_stats.cpu(), + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) cond_latent = self.gpt.get_style_emb(mel.to(self.device)) return cond_latent.transpose(1, 2) diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index c98ae804..81d1c4e5 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -27,16 +27,16 @@ PROJECT_NAME = "XTTS_trainer" DASHBOARD_LOGGER = "tensorboard" LOGGER_URI = None -# Set here the path that the checkpoints will be saved. 
Default: ./run/training/ OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests") os.makedirs(OUT_PATH, exist_ok=True) # Create DVAE checkpoint and mel_norms on test time # DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint +# Mel spectrogram norms, required for dvae mel spectrogram extraction MEL_NORM_FILE = os.path.join( OUT_PATH, "mel_stats.pth" -) # Mel spectrogram norms, required for dvae mel spectrogram extraction +) dvae = DiscreteVAE( channels=80, normalization=None, @@ -99,9 +99,7 @@ config = GPTTrainerConfig( model_args=model_args, run_name=RUN_NAME, project_name=PROJECT_NAME, - run_description=""" - GPT XTTS training - """, + run_description="GPT XTTS training", dashboard_logger=DASHBOARD_LOGGER, logger_uri=LOGGER_URI, audio=audio_config,