mirror of https://github.com/coqui-ai/TTS.git
add check arguments for GlowTTS and multispeaker training bug fix
This commit is contained in:
parent
f632f59f40
commit
b7f9ebd32b
|
@ -15,7 +15,7 @@ from TTS.tts.datasets.TTSDataset import MyDataset
|
||||||
from TTS.tts.layers.losses import GlowTTSLoss
|
from TTS.tts.layers.losses import GlowTTSLoss
|
||||||
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
|
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
|
||||||
reduce_tensor)
|
reduce_tensor)
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||||
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
|
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
|
||||||
|
@ -602,6 +602,7 @@ if __name__ == '__main__':
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
# check_config(c)
|
# check_config(c)
|
||||||
|
check_config_tts(c)
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
if c.apex_amp_level:
|
if c.apex_amp_level:
|
||||||
|
|
|
@ -37,7 +37,8 @@ class GlowTts(nn.Module):
|
||||||
hidden_channels_enc=None,
|
hidden_channels_enc=None,
|
||||||
hidden_channels_dec=None,
|
hidden_channels_dec=None,
|
||||||
use_encoder_prenet=False,
|
use_encoder_prenet=False,
|
||||||
encoder_type="transformer"):
|
encoder_type="transformer",
|
||||||
|
external_speaker_embedding_dim=None):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_chars = num_chars
|
self.num_chars = num_chars
|
||||||
|
@ -68,6 +69,13 @@ class GlowTts(nn.Module):
|
||||||
self.noise_scale = 0.66
|
self.noise_scale = 0.66
|
||||||
self.length_scale = 1.
|
self.length_scale = 1.
|
||||||
|
|
||||||
|
# if is a multispeaker and c_in_channels is 0, set to 256
|
||||||
|
if num_speakers > 1:
|
||||||
|
if self.c_in_channels == 0 and not external_speaker_embedding_dim:
|
||||||
|
self.c_in_channels = 256
|
||||||
|
elif external_speaker_embedding_dim:
|
||||||
|
self.c_in_channels = external_speaker_embedding_dim
|
||||||
|
|
||||||
self.encoder = Encoder(num_chars,
|
self.encoder = Encoder(num_chars,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
hidden_channels=hidden_channels,
|
hidden_channels=hidden_channels,
|
||||||
|
@ -94,7 +102,7 @@ class GlowTts(nn.Module):
|
||||||
sigmoid_scale=sigmoid_scale,
|
sigmoid_scale=sigmoid_scale,
|
||||||
c_in_channels=c_in_channels)
|
c_in_channels=c_in_channels)
|
||||||
|
|
||||||
if num_speakers > 1:
|
if num_speakers > 1 and not external_speaker_embedding_dim:
|
||||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||||
|
|
||||||
|
|
|
@ -129,10 +129,11 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
use_encoder_prenet=True)
|
use_encoder_prenet=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def is_tacotron(c):
|
||||||
|
return False if c['model'] == 'glow_tts' else True
|
||||||
|
|
||||||
def check_config_tts(c):
|
def check_config_tts(c):
|
||||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
|
||||||
check_argument('run_name', c, restricted=True, val_type=str)
|
check_argument('run_name', c, restricted=True, val_type=str)
|
||||||
check_argument('run_description', c, val_type=str)
|
check_argument('run_description', c, val_type=str)
|
||||||
|
|
||||||
|
@ -195,27 +196,30 @@ def check_config_tts(c):
|
||||||
check_argument('seq_len_norm', c, restricted=True, val_type=bool)
|
check_argument('seq_len_norm', c, restricted=True, val_type=bool)
|
||||||
|
|
||||||
# tacotron prenet
|
# tacotron prenet
|
||||||
check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
|
check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
|
||||||
check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
|
check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
|
||||||
check_argument('prenet_dropout', c, restricted=True, val_type=bool)
|
check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
|
|
||||||
# attention
|
# attention
|
||||||
check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
|
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
|
||||||
check_argument('attention_heads', c, restricted=True, val_type=int)
|
check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
|
||||||
check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
|
check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
|
||||||
check_argument('windowing', c, restricted=True, val_type=bool)
|
check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('use_forward_attn', c, restricted=True, val_type=bool)
|
check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
|
check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('transition_agent', c, restricted=True, val_type=bool)
|
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('transition_agent', c, restricted=True, val_type=bool)
|
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('location_attn', c, restricted=True, val_type=bool)
|
check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
|
check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
|
check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||||
|
|
||||||
# stopnet
|
# stopnet
|
||||||
check_argument('stopnet', c, restricted=True, val_type=bool)
|
check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('separate_stopnet', c, restricted=True, val_type=bool)
|
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
|
|
||||||
|
# GlowTTS parameters
|
||||||
|
check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
|
||||||
|
|
||||||
# tensorboard
|
# tensorboard
|
||||||
check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
||||||
|
@ -240,15 +244,16 @@ def check_config_tts(c):
|
||||||
|
|
||||||
# multi-speaker and gst
|
# multi-speaker and gst
|
||||||
check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
||||||
check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool)
|
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
|
||||||
check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str)
|
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
|
||||||
check_argument('use_gst', c, restricted=True, val_type=bool)
|
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('gst', c, restricted=True, val_type=dict)
|
if c['use_gst']:
|
||||||
check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict])
|
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
||||||
check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
||||||
check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool)
|
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
||||||
check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10)
|
check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
|
||||||
check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000)
|
check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
|
||||||
|
check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
|
||||||
|
|
||||||
# datasets - checking only the first entry
|
# datasets - checking only the first entry
|
||||||
check_argument('datasets', c, restricted=True, val_type=list)
|
check_argument('datasets', c, restricted=True, val_type=list)
|
||||||
|
|
Loading…
Reference in New Issue