Code linting

This commit is contained in:
Eren Gölge 2023-11-06 11:23:38 +01:00
parent b094979f1a
commit 55801049de
2 changed files with 22 additions and 31 deletions

View File

@@ -381,33 +381,26 @@ class Xtts(BaseTTS):
audio_22k = torchaudio.functional.resample(audio, sr, 22050) audio_22k = torchaudio.functional.resample(audio, sr, 22050)
audio_22k = audio_22k[:, : 22050 * length] audio_22k = audio_22k[:, : 22050 * length]
if self.args.gpt_use_perceiver_resampler: if self.args.gpt_use_perceiver_resampler:
mel = wav_to_mel_cloning( n_fft = 2048
audio_22k, hop_length = 256
mel_norms=self.mel_stats.cpu(), win_length = 1024
n_fft=2048,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
else: else:
mel = wav_to_mel_cloning( n_fft = 4096
audio_22k, hop_length = 1024
mel_norms=self.mel_stats.cpu(), win_length = 4096
n_fft=4096, mel = wav_to_mel_cloning(
hop_length=1024, audio_22k,
win_length=4096, mel_norms=self.mel_stats.cpu(),
power=2, n_fft=n_fft,
normalized=False, hop_length=hop_length,
sample_rate=22050, win_length=win_length,
f_min=0, power=2,
f_max=8000, normalized=False,
n_mels=80, sample_rate=22050,
) f_min=0,
f_max=8000,
n_mels=80,
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device)) cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2) return cond_latent.transpose(1, 2)

View File

@@ -27,16 +27,16 @@ PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard" DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None LOGGER_URI = None
# Set here the path that the checkpoints will be saved. Default: ./run/training/
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests") OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
os.makedirs(OUT_PATH, exist_ok=True) os.makedirs(OUT_PATH, exist_ok=True)
# Create DVAE checkpoint and mel_norms on test time # Create DVAE checkpoint and mel_norms on test time
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model # DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
# Mel spectrogram norms, required for dvae mel spectrogram extraction
MEL_NORM_FILE = os.path.join( MEL_NORM_FILE = os.path.join(
OUT_PATH, "mel_stats.pth" OUT_PATH, "mel_stats.pth"
) # Mel spectrogram norms, required for dvae mel spectrogram extraction )
dvae = DiscreteVAE( dvae = DiscreteVAE(
channels=80, channels=80,
normalization=None, normalization=None,
@@ -99,9 +99,7 @@ config = GPTTrainerConfig(
model_args=model_args, model_args=model_args,
run_name=RUN_NAME, run_name=RUN_NAME,
project_name=PROJECT_NAME, project_name=PROJECT_NAME,
run_description=""" run_description="GPT XTTS training",
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER, dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI, logger_uri=LOGGER_URI,
audio=audio_config, audio=audio_config,