diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 3d34d978..c5e570e5 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -15,7 +15,7 @@ from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.utils.distribute import (DistributedSampler, init_distributed, reduce_tensor) -from TTS.tts.utils.generic_utils import setup_model +from TTS.tts.utils.generic_utils import setup_model, check_config_tts from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping, @@ -602,6 +602,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) # check_config(c) + check_config_tts(c) _ = os.path.dirname(os.path.realpath(__file__)) if c.apex_amp_level: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 902de699..a9b6f8c0 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -37,7 +37,8 @@ class GlowTts(nn.Module): hidden_channels_enc=None, hidden_channels_dec=None, use_encoder_prenet=False, - encoder_type="transformer"): + encoder_type="transformer", + external_speaker_embedding_dim=None): super().__init__() self.num_chars = num_chars @@ -68,6 +69,13 @@ class GlowTts(nn.Module): self.noise_scale = 0.66 self.length_scale = 1. + # if is a multispeaker and c_in_channels is 0, set to 256 + if num_speakers > 1: + if self.c_in_channels == 0 and not external_speaker_embedding_dim: + self.c_in_channels = 256 + elif external_speaker_embedding_dim: + self.c_in_channels = external_speaker_embedding_dim + self.encoder = Encoder(num_chars, out_channels=out_channels, hidden_channels=hidden_channels, @@ -94,7 +102,7 @@ class GlowTts(nn.Module): sigmoid_scale=sigmoid_scale, c_in_channels=c_in_channels) - if num_speakers > 1: + if num_speakers > 1 and not external_speaker_embedding_dim: self.emb_g = nn.Embedding(num_speakers, c_in_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 5480cbcd..aacac898 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -129,10 +129,11 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): use_encoder_prenet=True) return model - +def is_tacotron(c): + return False if c['model'] == 'glow_tts' else True def check_config_tts(c): - check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) + check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) check_argument('run_description', c, val_type=str) @@ -195,27 +196,30 @@ def check_config_tts(c): check_argument('seq_len_norm', c, restricted=True, val_type=bool) # tacotron prenet - check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) - check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) - check_argument('prenet_dropout', c, restricted=True, val_type=bool) + check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1) + check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn']) + check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool) # attention - check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) - check_argument('attention_heads', c, restricted=True, val_type=int) - check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) - check_argument('windowing', c, restricted=True, val_type=bool) - check_argument('use_forward_attn', c, restricted=True, val_type=bool) - check_argument('forward_attn_mask', c, restricted=True, val_type=bool) - check_argument('transition_agent', c, restricted=True, val_type=bool) - check_argument('transition_agent', c, restricted=True, val_type=bool) - check_argument('location_attn', c, restricted=True, val_type=bool) - check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) - check_argument('double_decoder_consistency', c, restricted=True, val_type=bool) + check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original']) + check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int) + check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax']) + check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool) + check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool) + check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool) + check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) + check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) + check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool) + check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool) + check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool) check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) # stopnet - check_argument('stopnet', c, restricted=True, val_type=bool) - check_argument('separate_stopnet', c, restricted=True, val_type=bool) + check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) + check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + + # GlowTTS parameters + check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str) # tensorboard check_argument('print_step', c, restricted=True, val_type=int, min_val=1) @@ -240,15 +244,16 @@ def check_config_tts(c): # multi-speaker and gst check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool) - check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str) - check_argument('use_gst', c, restricted=True, val_type=bool) - check_argument('gst', c, restricted=True, val_type=dict) - check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict]) - check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool) - check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10) - check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000) + check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool) + check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str) + check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) + if c['use_gst']: + check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) + check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) + check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) + check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool) + check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) + check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) # datasets - checking only the first entry check_argument('datasets', c, restricted=True, val_type=list)