add check arguments for GlowTTS and multispeaker training bug fix

This commit is contained in:
Edresson 2020-10-19 17:17:58 -03:00
parent f632f59f40
commit b7f9ebd32b
3 changed files with 44 additions and 30 deletions

View File

@ -15,7 +15,7 @@ from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
reduce_tensor)
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
@ -602,6 +602,7 @@ if __name__ == '__main__':
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level:

View File

@ -37,7 +37,8 @@ class GlowTts(nn.Module):
hidden_channels_enc=None,
hidden_channels_dec=None,
use_encoder_prenet=False,
encoder_type="transformer"):
encoder_type="transformer",
external_speaker_embedding_dim=None):
super().__init__()
self.num_chars = num_chars
@ -68,6 +69,13 @@ class GlowTts(nn.Module):
self.noise_scale = 0.66
self.length_scale = 1.
# if is a multispeaker and c_in_channels is 0, set to 256
if num_speakers > 1:
if self.c_in_channels == 0 and not external_speaker_embedding_dim:
self.c_in_channels = 256
elif external_speaker_embedding_dim:
self.c_in_channels = external_speaker_embedding_dim
self.encoder = Encoder(num_chars,
out_channels=out_channels,
hidden_channels=hidden_channels,
@ -94,7 +102,7 @@ class GlowTts(nn.Module):
sigmoid_scale=sigmoid_scale,
c_in_channels=c_in_channels)
if num_speakers > 1:
if num_speakers > 1 and not external_speaker_embedding_dim:
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

View File

@ -129,10 +129,11 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
use_encoder_prenet=True)
return model
def is_tacotron(c):
return False if c['model'] == 'glow_tts' else True
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
check_argument('run_name', c, restricted=True, val_type=str)
check_argument('run_description', c, val_type=str)
@ -195,27 +196,30 @@ def check_config_tts(c):
check_argument('seq_len_norm', c, restricted=True, val_type=bool)
# tacotron prenet
check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
check_argument('prenet_dropout', c, restricted=True, val_type=bool)
check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
# attention
check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
check_argument('attention_heads', c, restricted=True, val_type=int)
check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
check_argument('windowing', c, restricted=True, val_type=bool)
check_argument('use_forward_attn', c, restricted=True, val_type=bool)
check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
check_argument('transition_agent', c, restricted=True, val_type=bool)
check_argument('transition_agent', c, restricted=True, val_type=bool)
check_argument('location_attn', c, restricted=True, val_type=bool)
check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
# stopnet
check_argument('stopnet', c, restricted=True, val_type=bool)
check_argument('separate_stopnet', c, restricted=True, val_type=bool)
check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
# GlowTTS parameters
check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
# tensorboard
check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
@ -240,15 +244,16 @@ def check_config_tts(c):
# multi-speaker and gst
check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool)
check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str)
check_argument('use_gst', c, restricted=True, val_type=bool)
check_argument('gst', c, restricted=True, val_type=dict)
check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict])
check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000)
check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool)
check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10)
check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000)
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
if c['use_gst']:
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
# datasets - checking only the first entry
check_argument('datasets', c, restricted=True, val_type=list)