mirror of https://github.com/coqui-ai/TTS.git
Add max_audio_length parameter
This commit is contained in:
parent
ceb8b05abe
commit
1a60767d83
|
@ -8,7 +8,7 @@ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrai
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
|
|
||||||
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path):
|
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
|
||||||
# Logging parameters
|
# Logging parameters
|
||||||
RUN_NAME = "GPT_XTTS_FT"
|
RUN_NAME = "GPT_XTTS_FT"
|
||||||
PROJECT_NAME = "XTTS_trainer"
|
PROJECT_NAME = "XTTS_trainer"
|
||||||
|
@ -79,7 +79,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
|
||||||
max_conditioning_length=132300, # 6 secs
|
max_conditioning_length=132300, # 6 secs
|
||||||
min_conditioning_length=66150, # 3 secs
|
min_conditioning_length=66150, # 3 secs
|
||||||
debug_loading_failures=False,
|
debug_loading_failures=False,
|
||||||
max_wav_length=255995, # ~11.6 seconds
|
max_wav_length=max_audio_length, # in audio samples; the default 255995 is ~11.6 seconds at 22050 Hz
|
||||||
max_text_length=200,
|
max_text_length=200,
|
||||||
mel_norm_file=MEL_NORM_FILE,
|
mel_norm_file=MEL_NORM_FILE,
|
||||||
dvae_checkpoint=DVAE_CHECKPOINT,
|
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||||
|
|
|
@ -147,6 +147,13 @@ if __name__ == "__main__":
|
||||||
help="Grad accumulation steps. Default: 1",
|
help="Grad accumulation steps. Default: 1",
|
||||||
default=1,
|
default=1,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_audio_length",
|
||||||
|
type=int,
|
||||||
|
help="Max permitted audio size in seconds. Default: 11",
|
||||||
|
default=11,
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
@ -250,6 +257,13 @@ if __name__ == "__main__":
|
||||||
step=1,
|
step=1,
|
||||||
value=args.grad_acumm,
|
value=args.grad_acumm,
|
||||||
)
|
)
|
||||||
|
max_audio_length = gr.Slider(
|
||||||
|
label="Max permitted audio size in seconds:",
|
||||||
|
minimum=2,
|
||||||
|
maximum=20,
|
||||||
|
step=1,
|
||||||
|
value=args.max_audio_length,
|
||||||
|
)
|
||||||
progress_train = gr.Label(
|
progress_train = gr.Label(
|
||||||
label="Progress:"
|
label="Progress:"
|
||||||
)
|
)
|
||||||
|
@ -260,12 +274,14 @@ if __name__ == "__main__":
|
||||||
demo.load(read_logs, None, logs_tts_train, every=1)
|
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||||
|
|
||||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path):
|
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
||||||
clear_gpu_cache()
|
clear_gpu_cache()
|
||||||
if not train_csv or not eval_csv:
|
if not train_csv or not eval_csv:
|
||||||
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
||||||
try:
|
try:
|
||||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
|
# convert seconds to a waveform length in samples (22050 samples per second)
|
||||||
|
max_audio_length = int(max_audio_length * 22050)
|
||||||
|
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
|
@ -280,7 +296,6 @@ if __name__ == "__main__":
|
||||||
clear_gpu_cache()
|
clear_gpu_cache()
|
||||||
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
||||||
|
|
||||||
|
|
||||||
with gr.Tab("3 - Inference"):
|
with gr.Tab("3 - Inference"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column() as col1:
|
with gr.Column() as col1:
|
||||||
|
@ -367,6 +382,7 @@ if __name__ == "__main__":
|
||||||
batch_size,
|
batch_size,
|
||||||
grad_acumm,
|
grad_acumm,
|
||||||
out_path,
|
out_path,
|
||||||
|
max_audio_length,
|
||||||
],
|
],
|
||||||
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue