mirror of https://github.com/coqui-ai/TTS.git
add check arguments for GlowTTS and multispeaker training bug fix
This commit is contained in:
parent
f632f59f40
commit
b7f9ebd32b
|
@ -15,7 +15,7 @@ from TTS.tts.datasets.TTSDataset import MyDataset
|
|||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
|
||||
reduce_tensor)
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
|
||||
|
@ -602,6 +602,7 @@ if __name__ == '__main__':
|
|||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
check_config_tts(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.apex_amp_level:
|
||||
|
|
|
@ -37,7 +37,8 @@ class GlowTts(nn.Module):
|
|||
hidden_channels_enc=None,
|
||||
hidden_channels_dec=None,
|
||||
use_encoder_prenet=False,
|
||||
encoder_type="transformer"):
|
||||
encoder_type="transformer",
|
||||
external_speaker_embedding_dim=None):
|
||||
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
|
@ -68,6 +69,13 @@ class GlowTts(nn.Module):
|
|||
self.noise_scale = 0.66
|
||||
self.length_scale = 1.
|
||||
|
||||
# if the model is multi-speaker and c_in_channels is 0, set it to 256
|
||||
if num_speakers > 1:
|
||||
if self.c_in_channels == 0 and not external_speaker_embedding_dim:
|
||||
self.c_in_channels = 256
|
||||
elif external_speaker_embedding_dim:
|
||||
self.c_in_channels = external_speaker_embedding_dim
|
||||
|
||||
self.encoder = Encoder(num_chars,
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
|
@ -94,7 +102,7 @@ class GlowTts(nn.Module):
|
|||
sigmoid_scale=sigmoid_scale,
|
||||
c_in_channels=c_in_channels)
|
||||
|
||||
if num_speakers > 1:
|
||||
if num_speakers > 1 and not external_speaker_embedding_dim:
|
||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
|
|
|
@ -129,10 +129,11 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
use_encoder_prenet=True)
|
||||
return model
|
||||
|
||||
|
||||
def is_tacotron(c):
    """Return True unless the config selects the 'glow_tts' model.

    Args:
        c: config mapping that exposes a 'model' entry via ``c['model']``.

    Returns:
        bool: False when ``c['model'] == 'glow_tts'``, True for any other value.
    """
    # Direct comparison instead of the redundant `False if ... else True`.
    return c['model'] != 'glow_tts'
|
||||
|
||||
def check_config_tts(c):
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
check_argument('run_description', c, val_type=str)
|
||||
|
||||
|
@ -195,27 +196,30 @@ def check_config_tts(c):
|
|||
check_argument('seq_len_norm', c, restricted=True, val_type=bool)
|
||||
|
||||
# tacotron prenet
|
||||
check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
|
||||
check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
|
||||
check_argument('prenet_dropout', c, restricted=True, val_type=bool)
|
||||
check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
|
||||
check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
|
||||
check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
|
||||
|
||||
# attention
|
||||
check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
|
||||
check_argument('attention_heads', c, restricted=True, val_type=int)
|
||||
check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
|
||||
check_argument('windowing', c, restricted=True, val_type=bool)
|
||||
check_argument('use_forward_attn', c, restricted=True, val_type=bool)
|
||||
check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
|
||||
check_argument('transition_agent', c, restricted=True, val_type=bool)
|
||||
check_argument('transition_agent', c, restricted=True, val_type=bool)
|
||||
check_argument('location_attn', c, restricted=True, val_type=bool)
|
||||
check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
|
||||
check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
|
||||
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
|
||||
check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
|
||||
check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
|
||||
check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||
|
||||
# stopnet
|
||||
check_argument('stopnet', c, restricted=True, val_type=bool)
|
||||
check_argument('separate_stopnet', c, restricted=True, val_type=bool)
|
||||
check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
|
||||
# GlowTTS parameters
|
||||
check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
|
||||
|
||||
# tensorboard
|
||||
check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
||||
|
@ -240,15 +244,16 @@ def check_config_tts(c):
|
|||
|
||||
# multi-speaker and gst
|
||||
check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
||||
check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool)
|
||||
check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str)
|
||||
check_argument('use_gst', c, restricted=True, val_type=bool)
|
||||
check_argument('gst', c, restricted=True, val_type=dict)
|
||||
check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict])
|
||||
check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||
check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool)
|
||||
check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10)
|
||||
check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000)
|
||||
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
|
||||
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
|
||||
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
||||
if c['use_gst']:
|
||||
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
||||
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
||||
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
||||
check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
|
||||
check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
|
||||
check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
|
||||
|
||||
# datasets - checking only the first entry
|
||||
check_argument('datasets', c, restricted=True, val_type=list)
|
||||
|
|
Loading…
Reference in New Issue