mirror of https://github.com/coqui-ai/TTS.git
Code linting
This commit is contained in:
parent
b094979f1a
commit
55801049de
|
@ -381,33 +381,26 @@ class Xtts(BaseTTS):
|
|||
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
|
||||
audio_22k = audio_22k[:, : 22050 * length]
|
||||
if self.args.gpt_use_perceiver_resampler:
|
||||
mel = wav_to_mel_cloning(
|
||||
audio_22k,
|
||||
mel_norms=self.mel_stats.cpu(),
|
||||
n_fft=2048,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
)
|
||||
n_fft = 2048
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
else:
|
||||
mel = wav_to_mel_cloning(
|
||||
audio_22k,
|
||||
mel_norms=self.mel_stats.cpu(),
|
||||
n_fft=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
)
|
||||
n_fft = 4096
|
||||
hop_length = 1024
|
||||
win_length = 4096
|
||||
mel = wav_to_mel_cloning(
|
||||
audio_22k,
|
||||
mel_norms=self.mel_stats.cpu(),
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
)
|
||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
|
|
|
@ -27,16 +27,16 @@ PROJECT_NAME = "XTTS_trainer"
|
|||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
|
||||
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
||||
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
|
||||
os.makedirs(OUT_PATH, exist_ok=True)
|
||||
|
||||
# Create DVAE checkpoint and mel_norms on test time
|
||||
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
||||
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
||||
# Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||
MEL_NORM_FILE = os.path.join(
|
||||
OUT_PATH, "mel_stats.pth"
|
||||
) # Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||
)
|
||||
dvae = DiscreteVAE(
|
||||
channels=80,
|
||||
normalization=None,
|
||||
|
@ -99,9 +99,7 @@ config = GPTTrainerConfig(
|
|||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
project_name=PROJECT_NAME,
|
||||
run_description="""
|
||||
GPT XTTS training
|
||||
""",
|
||||
run_description="GPT XTTS training",
|
||||
dashboard_logger=DASHBOARD_LOGGER,
|
||||
logger_uri=LOGGER_URI,
|
||||
audio=audio_config,
|
||||
|
|
Loading…
Reference in New Issue