mirror of https://github.com/coqui-ai/TTS.git
don't use sigmoid output for tacotron, fix bug for memory queue handling, remove maxout
This commit is contained in:
parent
ee706b50f6
commit
2bbb3f7a40
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"run_name": "ljspeech",
|
"run_name": "ljspeech",
|
||||||
"run_description": "gradual training with prenet frame size 1. Comparing to memory queue in gradual training. ",
|
"run_description": "gradual training with prenet frame size 1 + no maxout for cbhg + symmetric norm.",
|
||||||
|
|
||||||
"audio":{
|
"audio":{
|
||||||
// Audio processing parameters
|
// Audio processing parameters
|
||||||
|
@ -16,8 +16,8 @@
|
||||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||||
// Normalization parameters
|
// Normalization parameters
|
||||||
"signal_norm": true, // normalize the spec values in range [0, 1]
|
"signal_norm": true, // normalize the spec values in range [0, 1]
|
||||||
"symmetric_norm": false, // move normalization to range [-1, 1]
|
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||||
"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
"max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||||
"clip_norm": true, // clip normalized values into the range.
|
"clip_norm": true, // clip normalized values into the range.
|
||||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||||
|
@ -71,7 +71,7 @@
|
||||||
"dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
|
"dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
|
||||||
"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
|
"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
|
||||||
"max_seq_len": 150, // DATASET-RELATED: maximum text length
|
"max_seq_len": 150, // DATASET-RELATED: maximum text length
|
||||||
"output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
|
"output_path": "/media/erogol/data_ssd/Models/libri_tts/", // DATASET-RELATED: output path for all training outputs.
|
||||||
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||||
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
||||||
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
|
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
|
||||||
|
|
|
@ -135,9 +135,6 @@ class CBHG(nn.Module):
|
||||||
])
|
])
|
||||||
# max pooling of conv bank, with padding
|
# max pooling of conv bank, with padding
|
||||||
# TODO: try average pooling OR larger kernel size
|
# TODO: try average pooling OR larger kernel size
|
||||||
self.max_pool1d = nn.Sequential(
|
|
||||||
nn.ConstantPad1d([0, 1], value=0),
|
|
||||||
nn.MaxPool1d(kernel_size=2, stride=1, padding=0))
|
|
||||||
out_features = [K * conv_bank_features] + conv_projections[:-1]
|
out_features = [K * conv_bank_features] + conv_projections[:-1]
|
||||||
activations = [self.relu] * (len(conv_projections) - 1)
|
activations = [self.relu] * (len(conv_projections) - 1)
|
||||||
activations += [None]
|
activations += [None]
|
||||||
|
@ -186,7 +183,6 @@ class CBHG(nn.Module):
|
||||||
outs.append(out)
|
outs.append(out)
|
||||||
x = torch.cat(outs, dim=1)
|
x = torch.cat(outs, dim=1)
|
||||||
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
|
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
|
||||||
x = self.max_pool1d(x)
|
|
||||||
for conv1d in self.conv1d_projections:
|
for conv1d in self.conv1d_projections:
|
||||||
x = conv1d(x)
|
x = conv1d(x)
|
||||||
# (B, T_in, hid_feature)
|
# (B, T_in, hid_feature)
|
||||||
|
@ -387,7 +383,7 @@ class Decoder(nn.Module):
|
||||||
del decoder_input
|
del decoder_input
|
||||||
# predict mel vectors from decoder vectors
|
# predict mel vectors from decoder vectors
|
||||||
output = self.proj_to_mel(decoder_output)
|
output = self.proj_to_mel(decoder_output)
|
||||||
output = torch.sigmoid(output)
|
# output = torch.sigmoid(output)
|
||||||
# predict stop token
|
# predict stop token
|
||||||
stopnet_input = torch.cat([decoder_output, output], -1)
|
stopnet_input = torch.cat([decoder_output, output], -1)
|
||||||
del decoder_output
|
del decoder_output
|
||||||
|
@ -403,15 +399,15 @@ class Decoder(nn.Module):
|
||||||
if self.memory_size > self.r:
|
if self.memory_size > self.r:
|
||||||
# memory queue size is larger than number of frames per decoder iter
|
# memory queue size is larger than number of frames per decoder iter
|
||||||
self.memory_input = torch.cat([
|
self.memory_input = torch.cat([
|
||||||
self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
|
new_memory, self.memory_input[:, :(self.memory_size - self.r) * self.memory_dim].clone()
|
||||||
],
|
],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
else:
|
else:
|
||||||
# memory queue size smaller than number of frames per decoder iter
|
# memory queue size smaller than number of frames per decoder iter
|
||||||
self.memory_input = new_memory[:, (self.r - self.memory_size)*self.memory_dim:]
|
self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
|
||||||
else:
|
else:
|
||||||
# use only the last frame prediction
|
# use only the last frame prediction
|
||||||
self.memory_input = new_memory[:, (self.r-1) * self.memory_dim:]
|
self.memory_input = new_memory[:, :self.memory_dim]
|
||||||
|
|
||||||
def forward(self, inputs, memory, mask):
|
def forward(self, inputs, memory, mask):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -36,9 +36,7 @@ class Tacotron(nn.Module):
|
||||||
forward_attn, trans_agent, forward_attn_mask,
|
forward_attn, trans_agent, forward_attn_mask,
|
||||||
location_attn, separate_stopnet)
|
location_attn, separate_stopnet)
|
||||||
self.postnet = PostCBHG(mel_dim)
|
self.postnet = PostCBHG(mel_dim)
|
||||||
self.last_linear = nn.Sequential(
|
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
|
||||||
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
|
|
||||||
nn.Sigmoid())
|
|
||||||
|
|
||||||
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
|
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
|
||||||
B = characters.size(0)
|
B = characters.size(0)
|
||||||
|
|
Loading…
Reference in New Issue