move bash script based tests to python with coqpit

This commit is contained in:
Eren Gölge 2021-05-05 02:38:19 +02:00
parent 647163397d
commit 35341d5482
5 changed files with 51 additions and 3 deletions

View File

@ -398,7 +398,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
"Prior to November 22, 1963.", "Prior to November 22, 1963.",
] ]
# test sentences # test sentences
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}

View File

@ -2,12 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Argument parser for training scripts.""" """Argument parser for training scripts."""
import torch
import argparse import argparse
import glob import glob
import os import os
import re import re
import torch
from TTS.tts.utils.text.symbols import parse_symbols from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch from TTS.utils.generic_utils import create_experiment_folder, get_git_branch

View File

@ -111,7 +111,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# 2. filter out different size layers # 2. filter out different size layers
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers # 3. skip reinit layers
if c.has('reinit_layers') and c.reinit_layers is not None: if c.has("reinit_layers") and c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers: for reinit_layer_name in c.reinit_layers:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict # 4. overwrite entries in the existing state dict

View File

View File

@ -0,0 +1,48 @@
import glob
import os
from tests import get_tests_output_path, run_cli
from TTS.config import BaseDatasetConfig
from TTS.tts.configs import SpeedySpeechConfig
config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = SpeedySpeechConfig(
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_val_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='' python TTS/bin/train_speedy_speech.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='' python TTS/bin/train_speedy_speech.py --continue_path {continue_path} "
run_cli(command_train)