mirror of https://github.com/coqui-ai/TTS.git
add: check_config for speaker_encoder
This commit is contained in:
parent
0ea7f4e2bd
commit
df4caec4b7
|
@ -14,6 +14,7 @@ from TTS.speaker_encoder.dataset import MyDataset
|
|||
from TTS.speaker_encoder.generic_utils import save_best_model
|
||||
from TTS.speaker_encoder.losses import GE2ELoss, AngleProtoLoss
|
||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||
from TTS.speaker_encoder.utils import check_config_speaker_encoder
|
||||
from TTS.speaker_encoder.visual import plot_embeddings
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.utils.generic_utils import (
|
||||
|
@ -235,6 +236,7 @@ if __name__ == '__main__':
|
|||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
check_config_speaker_encoder(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
if args.data_path != '':
|
||||
c.data_path = args.data_path
|
||||
|
|
|
@ -17,7 +17,7 @@ from TTS.tts.layers.losses import TacotronLoss
|
|||
from TTS.tts.utils.distribute import (DistributedSampler,
|
||||
apply_gradient_allreduce,
|
||||
init_distributed, reduce_tensor)
|
||||
from TTS.tts.utils.generic_utils import check_config, setup_model
|
||||
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
|
||||
|
@ -670,7 +670,7 @@ if __name__ == '__main__':
|
|||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
check_config(c)
|
||||
check_config_tts(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.apex_amp_level == 'O1':
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
from TTS.utils.generic_utils import check_argument
|
||||
|
||||
|
||||
def check_config_speaker_encoder(c):
|
||||
"""Check the config.json file of the speaker encoder"""
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
check_argument('run_description', c, val_type=str)
|
||||
|
||||
# audio processing parameters
|
||||
check_argument('audio', c, restricted=True, val_type=dict)
|
||||
check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
||||
check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
||||
check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
||||
check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
||||
check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
||||
check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
||||
check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
||||
check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||
check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
||||
check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
||||
|
||||
# training parameters
|
||||
check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
|
||||
check_argument('grad_clip', c, restricted=True, val_type=float)
|
||||
check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
|
||||
check_argument('lr', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('lr_decay', c, restricted=True, val_type=bool)
|
||||
check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
|
||||
check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||
check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
|
||||
check_argument('num_loader_workers', c, restricted=True, val_type=int)
|
||||
check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||
|
||||
# checkpoint and output parameters
|
||||
check_argument('steps_plot_stats', c, restricted=True, val_type=int)
|
||||
check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||
check_argument('save_step', c, restricted=True, val_type=int)
|
||||
check_argument('print_step', c, restricted=True, val_type=int)
|
||||
check_argument('output_path', c, restricted=True, val_type=str)
|
||||
|
||||
# model parameters
|
||||
check_argument('model', c, restricted=True, val_type=dict)
|
||||
check_argument('input_dim', c['model'], restricted=True, val_type=int)
|
||||
check_argument('proj_dim', c['model'], restricted=True, val_type=int)
|
||||
check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
|
||||
check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
|
||||
check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)
|
||||
|
||||
# in-memory storage parameters
|
||||
check_argument('storage', c, restricted=True, val_type=dict)
|
||||
check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||
check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
|
||||
check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||
|
||||
# datasets - checking only the first entry
|
||||
check_argument('datasets', c, restricted=True, val_type=list)
|
||||
for dataset_entry in c['datasets']:
|
||||
check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
|
@ -100,7 +100,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
return model
|
||||
|
||||
|
||||
def check_config(c):
|
||||
def check_config_tts(c):
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
check_argument('run_description', c, val_type=str)
|
||||
|
@ -140,12 +140,6 @@ def check_config(c):
|
|||
check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
|
||||
check_argument('trim_db', c['audio'], restricted=True, val_type=int)
|
||||
|
||||
# storage parameters (only for speaker encoder)
|
||||
if 'storage' in c.keys():
|
||||
check_argument('sample_from_storage_p', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0)
|
||||
check_argument('storage_size', c['storage'], restricted=False, val_type=int, min_val=1, max_val=100)
|
||||
check_argument('additive_noise', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0)
|
||||
|
||||
# training parameters
|
||||
check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
|
|
|
@ -21,3 +21,4 @@ nose==1.3.7
|
|||
cardboardlint==1.3.0
|
||||
pylint==2.5.3
|
||||
gdown
|
||||
umap
|
Loading…
Reference in New Issue