mirror of https://github.com/coqui-ai/TTS.git
Added support for Tacotron2 GST + ability to condition style input with wav or tokens
This commit is contained in:
parent fe081d4f7c
commit 84b7ab6ee6
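The style conditioning added here accepts three interchangeable forms of input. A minimal sketch of the accepted values, assuming a config object C loaded with this repo's load_config (dict keys are style-token indices, values are attention amplifiers):

    C.gst['gst_style_input'] = "/path/to/reference.wav"            # condition on a reference wav
    C.gst['gst_style_input'] = {"0": 0.15, "1": 0.15, "5": -0.15}  # condition on style tokens
    C.gst['gst_style_input'] = None                                # neutral (zero) style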
@@ -27,7 +27,6 @@ class Tacotron2(TacotronAbstract):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 speaker_embedding_dim=None,
                  gst=False,
                  gst_embedding_dim=512,
                  gst_num_heads=4,
@@ -42,33 +41,18 @@ class Tacotron2(TacotronAbstract):
                          ddc_r, gst)

         # init layer dims
-        decoder_in_features = 512
-        encoder_in_features = 512
-
-        if speaker_embedding_dim is None:
-            # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim
-            self.embeddings_per_sample = False
-            speaker_embedding_dim = 512
-        else:
-            # if speaker_embedding_dim is not None we need use speaker embedding per sample
-            self.embeddings_per_sample = True
-
-        # speaker and gst embeddings is concat in decoder input
-        if num_speakers > 1:
-            decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim
-        if self.gst:
-            decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim
-
-        # embedding layer
+        speaker_embedding_dim = 512 if num_speakers > 1 else 0
+        gst_embedding_dim = gst_embedding_dim if self.gst else 0
+        decoder_in_features = 512 + speaker_embedding_dim + gst_embedding_dim
+        encoder_in_features = 512 if num_speakers > 1 else 512
+        proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # base layers
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)

         # speaker embedding layer
         if num_speakers > 1:
-            if not self.embeddings_per_sample:
-                self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
-                self.speaker_embedding.weight.data.normal_(0, 0.3)
-
-        # base model layers
+            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
+            self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
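For intuition, the dimension bookkeeping above works out as below; a sketch with the default sizes from this hunk (the num_speakers and GST settings are illustrative):

    num_speakers = 2
    use_gst = True
    speaker_embedding_dim = 512 if num_speakers > 1 else 0
    gst_embedding_dim = 512 if use_gst else 0
    decoder_in_features = 512 + speaker_embedding_dim + gst_embedding_dim
    print(decoder_in_features)  # 1536: encoder 512 + speaker 512 + GST 512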
@@ -99,7 +83,7 @@ class Tacotron2(TacotronAbstract):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
+    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
         # compute mask for padding
         # B x T_in_max (boolean)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -108,18 +92,20 @@ class Tacotron2(TacotronAbstract):
         # B x T_in_max x D_en
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)

-        if self.gst:
-            # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
-
         if self.num_speakers > 1:
-            if not self.embeddings_per_sample:
-                # B x 1 x speaker_embed_dim
-                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
-            else:
-                # B x 1 x speaker_embed_dim
-                speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

         encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
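The forward pass now tiles each utterance-level embedding across the time axis before concatenating, instead of going through the old _concat_speaker_embedding helper. A shape sketch (the sizes are illustrative):

    import torch

    encoder_outputs = torch.zeros(2, 5, 512)   # B x T x D_en
    gst_outputs = torch.zeros(2, 1, 512)       # B x 1 x gst_dim, one style vector per utterance
    embedded_gst = gst_outputs.repeat(1, encoder_outputs.size(1), 1)  # B x T x gst_dim
    out = torch.cat([encoder_outputs, embedded_gst], dim=-1)
    print(out.shape)                           # torch.Size([2, 5, 1024])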
@@ -147,18 +133,24 @@ class Tacotron2(TacotronAbstract):
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

     @torch.no_grad()
-    def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
+    def inference(self, text, speaker_ids=None, style_mel=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)

-        if self.gst:
-            # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
-
         if self.num_speakers > 1:
-            if not self.embeddings_per_sample:
-                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

         decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
@@ -168,21 +160,27 @@ class Tacotron2(TacotronAbstract):
             decoder_outputs, postnet_outputs, alignments)
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

-    def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
+    def inference_truncated(self, text, speaker_ids=None, style_mel=None):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)

-        if self.gst:
-            # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
-
         if self.num_speakers > 1:
-            if not self.embeddings_per_sample:
-                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

         mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
             encoder_outputs)
@@ -165,7 +165,6 @@ class TacotronAbstract(ABC, nn.Module):
                                 self.speaker_embeddings).squeeze(1)

     def compute_gst(self, inputs, style_input):
-        """ Compute global style token """
         device = inputs.device
         if isinstance(style_input, dict):
             query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
@@ -176,11 +175,17 @@ class TacotronAbstract(ABC, nn.Module):
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
             gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token in range(self.gst_style_tokens):
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * 0
         else:
-            gst_outputs = self.gst_layer(style_input)  # pylint: disable=not-callable
-            inputs = self._concat_speaker_embedding(inputs, gst_outputs)
-        return inputs
+            gst_outputs = self.gst_layer(style_input)
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst

     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
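compute_gst now returns the unchanged encoder inputs plus a time-tiled style embedding. The style_input is None branch runs the same token attention as the dict branch but multiplies each contribution by zero, so inference without a style reference still yields a correctly shaped, all-zero style embedding. A usage sketch, assuming a GST-enabled model instance (names follow this diff):

    inputs, embedded_gst = model.compute_gst(encoder_outputs, {"0": 0.15, "5": -0.15})  # token dict
    inputs, embedded_gst = model.compute_gst(encoder_outputs, style_mel)                # reference mel tensor
    inputs, embedded_gst = model.compute_gst(encoder_outputs, None)                     # zero style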
@@ -210,10 +210,13 @@ def synthesis(model,
     if backend == 'torch':
         if speaker_id is not None:
             speaker_id = id_to_torch(speaker_id, cuda=use_cuda)

         if speaker_embedding is not None:
             speaker_embedding = embedding_to_torch(speaker_embedding, cuda=use_cuda)

         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
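Note that a dict-valued style input is passed through untouched here; only array inputs are converted to tensors. A hypothetical call, with argument names taken from the tts() helper in the new script below:

    waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
        model, "Hello world.", C, use_cuda, ap, speaker_id,
        style_wav={"0": 0.15, "1": 0.15, "5": -0.15},  # token dict instead of a wav path
        truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
        use_griffin_lim=True, do_trim_silence=True)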
@@ -0,0 +1,182 @@
# pylint: disable=redefined-outer-name, unused-argument
import os
import time
import argparse
import torch
import json
import string

from TTS.utils.synthesis import synthesis
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config
from TTS.utils.text.symbols import make_symbols, symbols, phonemes
from TTS.utils.audio import AudioProcessor


def tts(model,
        vocoder_model,
        C,
        VC,
        text,
        ap,
        ap_vocoder,
        use_cuda,
        batched_vocoder,
        speaker_id=None,
        figures=False):
    t_1 = time.time()
    use_vocoder_model = vocoder_model is not None
    waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
        model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'],
        truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
        use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)

    if C.model == "Tacotron" and use_vocoder_model:
        postnet_output = ap.out_linear_to_mel(postnet_output.T).T
    # correct if there is a scale difference b/w the two models
    if use_vocoder_model:
        postnet_output = ap._denormalize(postnet_output)
        postnet_output = ap_vocoder._normalize(postnet_output)
        vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
        waveform = vocoder_model.generate(
            vocoder_input.cuda() if use_cuda else vocoder_input,
            batched=batched_vocoder,
            target=8000,
            overlap=400)
    print(" > Run-time: {}".format(time.time() - t_1))
    return alignment, postnet_output, stop_tokens, waveform


if __name__ == "__main__":

    global symbols, phonemes

    parser = argparse.ArgumentParser()
    parser.add_argument('text', type=str, help='Text to generate speech.')
    parser.add_argument('config_path',
                        type=str,
                        help='Path to model config file.')
    parser.add_argument(
        'model_path',
        type=str,
        help='Path to model file.',
    )
    parser.add_argument(
        'out_path',
        type=str,
        help='Path to save the final wav file. The wav file will be named after the given text.',
    )
    parser.add_argument('--use_cuda',
                        type=bool,
                        help='Run model on CUDA.',
                        default=False)
    parser.add_argument(
        '--vocoder_path',
        type=str,
        help=
        'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
        default="",
    )
    parser.add_argument('--vocoder_config_path',
                        type=str,
                        help='Path to vocoder model config file.',
                        default="")
    parser.add_argument(
        '--batched_vocoder',
        type=bool,
        help="If True, vocoder model uses faster batch processing.",
        default=True)
    parser.add_argument('--speakers_json',
                        type=str,
                        help="JSON file for multi-speaker model.",
                        default="")
    parser.add_argument(
        '--speaker_id',
        type=int,
        help="target speaker_id if the model is multi-speaker.",
        default=None)
    args = parser.parse_args()

    if args.vocoder_path != "":
        assert args.use_cuda, " [!] Enable cuda for vocoder."
        from WaveRNN.models.wavernn import Model as VocoderModel

    # load the config
    C = load_config(args.config_path)
    C.forward_attn_mask = True

    # load the audio processor
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load speakers
    if args.speakers_json != '':
        speakers = json.load(open(args.speakers_json, 'r'))
        num_speakers = len(speakers)
    else:
        num_speakers = 0

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(num_chars, num_speakers, C)
    cp = torch.load(args.model_path)
    model.load_state_dict(cp['model'])
    model.eval()
    if args.use_cuda:
        model.cuda()
    model.decoder.set_r(cp['r'])

    # load vocoder model
    if args.vocoder_path != "":
        VC = load_config(args.vocoder_config_path)
        ap_vocoder = AudioProcessor(**VC.audio)
        bits = 10
        vocoder_model = VocoderModel(rnn_dims=512,
                                     fc_dims=512,
                                     mode=VC.mode,
                                     mulaw=VC.mulaw,
                                     pad=VC.pad,
                                     upsample_factors=VC.upsample_factors,
                                     feat_dims=VC.audio["num_mels"],
                                     compute_dims=128,
                                     res_out_dims=128,
                                     res_blocks=10,
                                     hop_length=ap.hop_length,
                                     sample_rate=ap.sample_rate,
                                     use_aux_net=True,
                                     use_upsample_net=True)

        check = torch.load(args.vocoder_path)
        vocoder_model.load_state_dict(check['model'])
        vocoder_model.eval()
        if args.use_cuda:
            vocoder_model.cuda()
    else:
        vocoder_model = None
        VC = None
        ap_vocoder = None

    # synthesize voice
    print(" > Text: {}".format(args.text))
    _, _, _, wav = tts(model,
                       vocoder_model,
                       C,
                       VC,
                       args.text,
                       ap,
                       ap_vocoder,
                       args.use_cuda,
                       args.batched_vocoder,
                       speaker_id=args.speaker_id,
                       figures=False)

    # save the results
    file_name = args.text.replace(" ", "_")
    file_name = file_name.translate(
        str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
    out_path = os.path.join(args.out_path, file_name)
    print(" > Saving output to {}".format(out_path))
    ap.save_wav(wav, out_path)
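Two usage notes on the script above. First, --use_cuda and --batched_vocoder are declared with type=bool, so argparse turns any non-empty string (including "False") into True; omit the flags to keep their defaults. Second, tts() can also be called programmatically with Griffin-Lim only; a sketch that mirrors the __main__ block, assuming C, ap, and model are already loaded:

    alignment, postnet_output, stop_tokens, wav = tts(
        model, None, C, None, "Hello world.", ap, None,
        use_cuda=False, batched_vocoder=False, speaker_id=None, figures=False)
    ap.save_wav(wav, "hello_world.wav")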
@@ -1,3 +1,4 @@
 {
     "model": "Tacotron2",
     "run_name": "test_sample_dataset_run",
@@ -150,3 +151,161 @@
 }

+{
+    "model": "Tacotron2",
+    "run_name": "ljspeech-ddc-bn",
+    "run_description": "tacotron2 with ddc and batch-normalization",
+
+    // AUDIO PARAMETERS
+    "audio":{
+        // stft parameters
+        "fft_size": 1024,         // number of stft frequency levels. Size of the linear spectrogram frame.
+        "win_length": 1024,       // stft window length in samples.
+        "hop_length": 256,        // stft window hop length in samples.
+        "frame_length_ms": null,  // stft window length in ms. If null, 'win_length' is used.
+        "frame_shift_ms": null,   // stft window hop length in ms. If null, 'hop_length' is used.
+
+        // Audio processing parameters
+        "sample_rate": 22050,   // DATASET-RELATED: wav sample-rate.
+        "preemphasis": 0.0,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
+        "ref_level_db": 20,     // reference level db, theoretically 20db is the sound of air.
+
+        // Silence trimming
+        "do_trim_silence": true,  // enable trimming of silence of audio as you load it. LJSpeech (true), TWEB (false), Nancy (true)
+        "trim_db": 60,            // threshold for trimming silence. Set this according to your dataset.
+
+        // Griffin-Lim
+        "power": 1.5,             // value to sharpen wav signals after GL algorithm.
+        "griffin_lim_iters": 60,  // #griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+
+        // MelSpectrogram parameters
+        "num_mels": 80,      // size of the mel spec frame.
+        "mel_fmin": 0.0,     // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": 8000.0,  // maximum freq level for mel-spec. Tune for dataset!!
+        "spec_gain": 20,
+
+        // Normalization parameters
+        "signal_norm": true,     // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
+        "min_level_db": -100,    // lower bound for normalization
+        "symmetric_norm": true,  // move normalization to range [-1, 1]
+        "max_norm": 4.0,         // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true,       // clip normalized values into the range.
+        "stats_path": null       // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
+    },
+
+    // VOCABULARY PARAMETERS
+    // if custom character set is not defined,
+    // default set in symbols.py is used
+    // "characters":{
+    //     "pad": "_",
+    //     "eos": "~",
+    //     "bos": "^",
+    //     "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
+    //     "punctuations":"!'(),-.:;? ",
+    //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
+    // },
+
+    // DISTRIBUTED TRAINING
+    "distributed":{
+        "backend": "nccl",
+        "url": "tcp:\/\/localhost:54321"
+    },
+
+    "reinit_layers": [],  // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
+
+    // TRAINING
+    "batch_size": 32,      // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+    "eval_batch_size": 16,
+    "r": 7,                // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
+    "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]],  // set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
+    "loss_masking": true,  // enable / disable loss masking against the sequence padding.
+    "ga_alpha": 10.0,      // weight for guided attention loss. If > 0, guided attention is enabled.
+
+    // VALIDATION
+    "run_eval": true,
+    "test_delay_epochs": 10,      // Until attention is aligned, testing only wastes computation time.
+    "test_sentences_file": null,  // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
+
+    // OPTIMIZER
+    "noam_schedule": false,  // use noam warmup and lr schedule.
+    "grad_clip": 1.0,        // upper limit for gradients for clipping.
+    "epochs": 1000,          // total number of epochs to train.
+    "lr": 0.0001,            // Initial learning rate. If Noam decay is active, maximum learning rate.
+    "wd": 0.000001,          // Weight decay weight.
+    "warmup_steps": 4000,    // Noam decay steps to increase the learning rate from 0 to "lr"
+    "seq_len_norm": false,   // Normalize each sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has a skewed distribution of sequence lengths.
+
+    // TACOTRON PRENET
+    "memory_size": -1,        // ONLY TACOTRON - size of the memory queue used for storing the last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
+    "prenet_type": "bn",      // "original" or "bn".
+    "prenet_dropout": false,  // enable/disable dropout at prenet.
+
+    // TACOTRON ATTENTION
+    "attention_type": "original",  // 'original' or 'graves'
+    "attention_heads": 4,          // number of attention heads (only for 'graves')
+    "attention_norm": "sigmoid",   // softmax or sigmoid.
+    "windowing": false,            // Enables attention windowing. Used only in eval mode.
+    "use_forward_attn": false,     // if it uses forward attention. In general, it aligns faster.
+    "forward_attn_mask": false,    // Additional masking forcing monotonicity only in eval mode.
+    "transition_agent": false,     // enable/disable transition agent of forward attention.
+    "location_attn": true,         // enable/disable location sensitive attention. It is enabled for TACOTRON by default.
+    "bidirectional_decoder": false,      // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
+    "double_decoder_consistency": true,  // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
+    "ddc_r": 7,                    // reduction rate for coarse decoder.
+
+    // STOPNET
+    "stopnet": true,           // Train stopnet predicting the end of synthesis.
+    "separate_stopnet": true,  // Train stopnet separately if 'stopnet==true'. It prevents the stopnet loss from influencing the rest of the model. It causes a better model, but it trains SLOWER.
+
+    // TENSORBOARD and LOGGING
+    "print_step": 25,     // Number of steps to log training on console.
+    "tb_plot_step": 100,  // Number of steps to plot TB training figures.
+    "print_eval": false,  // If True, it prints intermediate loss values in evaluation.
+    "save_step": 10000,   // Number of training steps expected to save training stats and checkpoints.
+    "checkpoint": true,   // If true, it saves checkpoints per "save_step"
+    "tb_model_param_stats": false,  // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+
+    // DATA LOADING
+    "text_cleaner": "phoneme_cleaners",
+    "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
+    "num_loader_workers": 4,        // number of training data loader processes. Don't set it too big. 4-8 are good values.
+    "num_val_loader_workers": 4,    // number of evaluation data loader processes.
+    "batch_group_size": 0,          // Number of batches to shuffle after bucketing.
+    "min_seq_len": 6,               // DATASET-RELATED: minimum text length to use in training
+    "max_seq_len": 153,             // DATASET-RELATED: maximum text length
+
+    // PATHS
+    "output_path": "/home/erogol/Models/LJSpeech/",
+
+    // PHONEMES
+    "phoneme_cache_path": "/media/erogol/data_ssd2/mozilla_us_phonemes_3",  // phoneme computation is slow, therefore, it caches results in the given folder.
+    "use_phonemes": true,         // use phonemes instead of raw characters. It is suggested for better pronunciation.
+    "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
+
+    // MULTI-SPEAKER and GST
+    "use_speaker_embedding": false,  // use speaker embedding to enable multi-speaker learning.
+    "use_gst": true,                 // use global style tokens
+    "gst": {  // gst parameters if gst is enabled
+        "gst_style_input": null,  // Condition the style input either on a
+                                  // -> wave file [path to wave] or
+                                  // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'}, e.g. {"0": 0.15, "1": 0.15, "5": -0.15},
+                                  // with each key being a token index in [0, gst_style_tokens).
+        "gst_embedding_dim": 512,
+        "gst_num_heads": 4,
+        "gst_style_tokens": 10
+    },
+
+    // DATASETS
+    "datasets":  // List of datasets. They are all merged and they get different speaker_ids.
+        [
+            {
+                "name": "ljspeech",
+                "path": "/home/erogol/Data/LJSpeech-1.1/",
+                "meta_file_train": "metadata.csv",
+                "meta_file_val": null
+            }
+        ]
+}
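To condition on tokens rather than a wav at synthesis time, gst_style_input can also be set after loading the config; a sketch assuming this repo's load_config:

    from TTS.utils.io import load_config

    C = load_config("config.json")
    C.gst['gst_style_input'] = {"0": 0.15, "5": -0.15}   # token index -> amplifier
    # or a reference wav instead:
    # C.gst['gst_style_input'] = "/path/to/reference.wav"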
@@ -0,0 +1,374 @@
import os
import glob
import torch
import shutil
import datetime
import subprocess
import importlib
import numpy as np
from collections import Counter


def get_git_branch():
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
        current = next(line for line in out.split("\n")
                       if line.startswith("*"))
        current = current.replace("* ", "")
    except subprocess.CalledProcessError:
        current = "inside_docker"
    return current
def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    # try:
    #     subprocess.check_output(['git', 'diff-index', '--quiet',
    #                              'HEAD'])  # Verify client is clean
    # except:
    #     raise RuntimeError(
    #         " !! Commit before training to get the commit hash.")
    try:
        commit = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
    # Not copying .git folder into docker container
    except subprocess.CalledProcessError:
        commit = "0000000"
    print(' > Git Hash: {}'.format(commit))
    return commit


def create_experiment_folder(root_path, model_name, debug):
    """ Create a folder with the current date and time """
    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    if debug:
        commit_hash = 'debug'
    else:
        commit_hash = get_commit_hash()
    output_folder = os.path.join(
        root_path, model_name + '-' + date_str + '-' + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder
def remove_experiment_folder(experiment_path):
    """Remove the experiment folder unless it contains a checkpoint"""

    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
    if not checkpoint_files:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path, ignore_errors=True)
            print(" ! Run is removed from {}".format(experiment_path))
    else:
        print(" ! Run is kept in {}".format(experiment_path))


def count_parameters(model):
    r"""Count number of trainable parameters in a network"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def split_dataset(items):
    is_multi_speaker = False
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = 500 if len(items) * 0.01 > 500 else int(
        len(items) * 0.01)
    assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        # most stupid code ever -- Fix it !
        while len(items_eval) < eval_split_size:
            speakers = [item[-1] for item in items]
            speaker_counter = Counter(speakers)
            item_idx = np.random.randint(0, len(items))
            if speaker_counter[items[item_idx][-1]] > 1:
                items_eval.append(items[item_idx])
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]
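# Usage sketch for split_dataset() (illustrative): items are rows of
# [wav_path, transcript, speaker_name]; the 1% eval split must be non-empty,
# so the assert above effectively requires at least ~100 rows.
#
#   items = [["wav_%d.wav" % i, "some text", "speaker_a"] for i in range(200)]
#   eval_items, train_items = split_dataset(items)
#   print(len(eval_items), len(train_items))  # 2 198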
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.to(sequence_length.device)
    seq_length_expand = (
        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    # B x T_max
    return seq_range_expand < seq_length_expand
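# Usage sketch for sequence_mask() (illustrative):
#
#   lengths = torch.tensor([3, 1, 2])
#   sequence_mask(lengths)
#   # tensor([[ True,  True,  True],
#   #         [ True, False, False],
#   #         [ True,  True, False]])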
def set_init_dict(model_dict, checkpoint_state, c):
    # Partial initialization: layers that mismatch between the new and old model are skipped.
    for k, v in checkpoint_state.items():
        if k not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(k))
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in checkpoint_state.items() if k in model_dict
    }
    # 2. filter out different size layers
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if v.numel() == model_dict[k].numel()
    }
    # 3. skip reinit layers
    if c.reinit_layers is not None:
        for reinit_layer_name in c.reinit_layers:
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if reinit_layer_name not in k
            }
    # 4. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
                                                     len(model_dict)))
    return model_dict
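# Usage sketch for set_init_dict() (illustrative): partially restore a
# checkpoint whose layer names or shapes only partly match the current model.
#
#   model_dict = model.state_dict()
#   checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
#   model_dict = set_init_dict(model_dict, checkpoint['model'], c)
#   model.load_state_dict(model_dict)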
def setup_model(num_chars, num_speakers, c):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('TTS.models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() in "tacotron":
        model = MyModel(num_chars=num_chars,
                        num_speakers=num_speakers,
                        r=c.r,
                        postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
                        decoder_output_dim=c.audio['num_mels'],
                        gst=c.use_gst,
                        gst_embedding_dim=c.gst['gst_embedding_dim'],
                        gst_num_heads=c.gst['gst_num_heads'],
                        gst_style_tokens=c.gst['gst_style_tokens'],
                        memory_size=c.memory_size,
                        attn_type=c.attention_type,
                        attn_win=c.windowing,
                        attn_norm=c.attention_norm,
                        prenet_type=c.prenet_type,
                        prenet_dropout=c.prenet_dropout,
                        forward_attn=c.use_forward_attn,
                        trans_agent=c.transition_agent,
                        forward_attn_mask=c.forward_attn_mask,
                        location_attn=c.location_attn,
                        attn_K=c.attention_heads,
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder,
                        double_decoder_consistency=c.double_decoder_consistency,
                        ddc_r=c.ddc_r)
    elif c.model.lower() == "tacotron2":
        model = MyModel(num_chars=num_chars,
                        num_speakers=num_speakers,
                        r=c.r,
                        postnet_output_dim=c.audio['num_mels'],
                        decoder_output_dim=c.audio['num_mels'],
                        gst=c.use_gst,
                        gst_embedding_dim=c.gst['gst_embedding_dim'],
                        gst_num_heads=c.gst['gst_num_heads'],
                        gst_style_tokens=c.gst['gst_style_tokens'],
                        attn_type=c.attention_type,
                        attn_win=c.windowing,
                        attn_norm=c.attention_norm,
                        prenet_type=c.prenet_type,
                        prenet_dropout=c.prenet_dropout,
                        forward_attn=c.use_forward_attn,
                        trans_agent=c.transition_agent,
                        forward_attn_mask=c.forward_attn_mask,
                        location_attn=c.location_attn,
                        attn_K=c.attention_heads,
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder,
                        double_decoder_consistency=c.double_decoder_consistency,
                        ddc_r=c.ddc_r)
    return model
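# Usage sketch for setup_model() (illustrative; num_chars depends on the symbol
# set in use). Note the first branch tests `c.model.lower() in "tacotron"`,
# i.e. a substring match, so "tacotron2" falls through to the elif:
#
#   from TTS.utils.io import load_config
#   C = load_config('config.json')
#   model = setup_model(num_chars=129, num_speakers=0, c=C)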
class KeepAverage():
    def __init__(self):
        self.avg_values = {}
        self.iters = {}

    def __getitem__(self, key):
        return self.avg_values[key]

    def items(self):
        return self.avg_values.items()

    def add_value(self, name, init_val=0, init_iter=0):
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        if name not in self.avg_values:
            # add the value if it does not exist yet
            self.add_value(name, init_val=value)
        else:
            # else update the existing value
            if weighted_avg:
                self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
                self.iters[name] += 1
            else:
                self.avg_values[name] = self.avg_values[name] * \
                    self.iters[name] + value
                self.iters[name] += 1
                self.avg_values[name] /= self.iters[name]

    def add_values(self, name_dict):
        for key, value in name_dict.items():
            self.add_value(key, init_val=value)

    def update_values(self, value_dict):
        for key, value in value_dict.items():
            self.update_value(key, value)
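# Usage sketch for KeepAverage (illustrative). Note that the first
# update_value() call only registers the value with an iteration count of 0,
# so it carries no weight in the subsequent running mean:
#
#   ka = KeepAverage()
#   ka.update_value('loss', 4.0)  # registered, iters == 0
#   ka.update_value('loss', 2.0)
#   ka.update_value('loss', 6.0)
#   ka['loss']  # 4.0 == (2.0 + 6.0) / 2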
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
    if alternative in c.keys() and c[alternative] is not None:
        return
    if restricted:
        assert name in c.keys(), f' [!] {name} not defined in config.json'
    if name in c.keys():
        if max_val:
            assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
        if min_val:
            assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
        if enum_list:
            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
        if val_type:
            assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
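# Usage sketch for _check_argument() (illustrative). The max_val/min_val
# guards test truthiness rather than None, so a bound of 0 is effectively
# skipped:
#
#   c = {'lr': 0.0001}
#   _check_argument('lr', c, restricted=True, val_type=float, min_val=0)  # passes
#   _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
#   # AssertionError:  [!] batch_size not defined in config.json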
def check_config(c):
    _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
    _check_argument('run_name', c, restricted=True, val_type=str)
    _check_argument('run_description', c, val_type=str)

    # AUDIO
    _check_argument('audio', c, restricted=True, val_type=dict)

    # audio processing parameters
    _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    _check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
    _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

    # vocabulary parameters
    _check_argument('characters', c, restricted=False, val_type=dict)
    _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)

    # normalization parameters
    _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
    _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
    _check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100)
    _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
    _check_argument('trim_db', c['audio'], restricted=True, val_type=int)

    # training parameters
    _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
    _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
    _check_argument('r', c, restricted=True, val_type=int, min_val=1)
    _check_argument('gradual_training', c, restricted=False, val_type=list)
    _check_argument('loss_masking', c, restricted=True, val_type=bool)
    # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

    # validation parameters
    _check_argument('run_eval', c, restricted=True, val_type=bool)
    _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
    _check_argument('test_sentences_file', c, restricted=False, val_type=str)

    # optimizer
    _check_argument('noam_schedule', c, restricted=False, val_type=bool)
    _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
    _check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    _check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    _check_argument('wd', c, restricted=True, val_type=float, min_val=0)
    _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    _check_argument('seq_len_norm', c, restricted=True, val_type=bool)

    # tacotron prenet
    _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
    _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
    _check_argument('prenet_dropout', c, restricted=True, val_type=bool)

    # attention
    _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
    _check_argument('attention_heads', c, restricted=True, val_type=int)
    _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
    _check_argument('windowing', c, restricted=True, val_type=bool)
    _check_argument('use_forward_attn', c, restricted=True, val_type=bool)
    _check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
    _check_argument('transition_agent', c, restricted=True, val_type=bool)
    _check_argument('location_attn', c, restricted=True, val_type=bool)
    _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
    _check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
    _check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)

    # stopnet
    _check_argument('stopnet', c, restricted=True, val_type=bool)
    _check_argument('separate_stopnet', c, restricted=True, val_type=bool)

    # tensorboard
    _check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('checkpoint', c, restricted=True, val_type=bool)
    _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)

    # dataloading
    # pylint: disable=import-outside-toplevel
    from TTS.utils.text import cleaners
    _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
    _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
    _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
    _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
    _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
    _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
    _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)

    # paths
    _check_argument('output_path', c, restricted=True, val_type=str)

    # multi-speaker
    _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)

    # GST
    _check_argument('use_gst', c, restricted=True, val_type=bool)
    _check_argument('gst_style_input', c, restricted=True, val_type=str)
    _check_argument('gst', c, restricted=True, val_type=dict)
    _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
    _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
    _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)

    # datasets - checking only the first entry
    _check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        _check_argument('name', dataset_entry, restricted=True, val_type=str)
        _check_argument('path', dataset_entry, restricted=True, val_type=str)
        _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
        _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)