mirror of https://github.com/coqui-ai/TTS.git
small bug fixes
This commit is contained in:
parent
7d5bcd6ca4
commit
dadefb5dbc
|
@ -72,7 +72,7 @@ class LJSpeechDataset(Dataset):
|
|||
timesteps = mel.shape[2]
|
||||
|
||||
# PAD with zeros that can be divided by outputs per step
|
||||
# if timesteps % self.outputs_per_step != 0:
|
||||
if timesteps % self.outputs_per_step != 0:
|
||||
linear = pad_per_step(linear, self.outputs_per_step)
|
||||
mel = pad_per_step(mel, self.outputs_per_step)
|
||||
|
||||
|
@ -80,6 +80,7 @@ class LJSpeechDataset(Dataset):
|
|||
linear = linear.transpose(0, 2, 1)
|
||||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
# convert things to pytorch
|
||||
text_lenghts = torch.LongTensor(text_lenghts)
|
||||
text = torch.LongTensor(text)
|
||||
linear = torch.FloatTensor(linear)
|
||||
|
|
|
@ -60,12 +60,12 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ROOT_PATH = '../result/February-03-2018_08:32PM/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/checkpoint_3800.pth.tar'\n",
|
||||
"ROOT_PATH = '../result/February-08-2018_05:35AM/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/checkpoint_40600.pth.tar'\n",
|
||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||
"OUT_FOLDER = ROOT_PATH + '/test/'\n",
|
||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||
|
@ -74,7 +74,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -83,6 +83,18 @@
|
|||
"text": [
|
||||
" | > Embedding dim : 149\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "TypeError",
|
||||
"evalue": "__init__() missing 1 required positional argument: 'r'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-6-561a2605d2cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# load the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m model = Tacotron(CONFIG.embedding_size, CONFIG.hidden_size,\n\u001b[0;32m----> 3\u001b[0;31m CONFIG.num_mels, CONFIG.num_freq, CONFIG.r)\n\u001b[0m\u001b[1;32m 4\u001b[0m ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n\u001b[1;32m 5\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_shift_ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_length_ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreemphasis\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/projects/TTS/models/tacotron.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, embedding_dim, linear_dim, mel_dim, freq_dim, r, padding_idx, use_memory_mask)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEncoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membedding_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmel_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpostnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCBHG\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmel_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprojections\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmel_dim\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'r'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
|
@ -332,7 +344,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.3"
|
||||
"version": "3.6.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
24
train.py
24
train.py
|
@ -121,6 +121,7 @@ def main(args):
|
|||
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
|
||||
# patience=c.lr_patience, verbose=True)
|
||||
epoch_time = 0
|
||||
best_loss = float('inf')
|
||||
for epoch in range(0, c.epochs):
|
||||
|
||||
print("\n | > Epoch {}/{}".format(epoch, c.epochs))
|
||||
|
@ -131,7 +132,7 @@ def main(args):
|
|||
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
magnitude_input = data[2]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
|
||||
current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
|
||||
|
@ -153,9 +154,10 @@ def main(args):
|
|||
# convert inputs to variables
|
||||
text_input_var = Variable(text_input)
|
||||
mel_spec_var = Variable(mel_input)
|
||||
linear_spec_var = Variable(magnitude_input, volatile=True)
|
||||
linear_spec_var = Variable(linear_input, volatile=True)
|
||||
|
||||
# sort sequence by length. Pytorch needs this.
|
||||
# sort sequence by length.
|
||||
# TODO: might be unnecessary
|
||||
sorted_lengths, indices = torch.sort(
|
||||
text_lengths.view(-1), dim=0, descending=True)
|
||||
sorted_lengths = sorted_lengths.long().numpy()
|
||||
|
@ -211,18 +213,10 @@ def main(args):
|
|||
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||
|
||||
if current_step % c.save_step == 0:
|
||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
|
||||
save_checkpoint({'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
'total_loss': loss.data[0],
|
||||
'linear_loss': linear_loss.data[0],
|
||||
'mel_loss': mel_loss.data[0],
|
||||
'date': datetime.date.today().strftime("%B %d, %Y")},
|
||||
checkpoint_path)
|
||||
print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
|
||||
|
||||
# save model
|
||||
best_loss = save_checkpoint(model, loss.data[0],
|
||||
best_loss, out_path=OUT_PATH)
|
||||
|
||||
# Diagnostic visualizations
|
||||
const_spec = linear_output[0].data.cpu().numpy()
|
||||
|
|
Binary file not shown.
|
@ -15,6 +15,7 @@ def prepare_data(inputs):
|
|||
|
||||
|
||||
def pad_per_step(inputs, outputs_per_step):
|
||||
"""zero pad inputs if it is not divisible with outputs_per_step (r)"""
|
||||
timesteps = inputs.shape[-1]
|
||||
return np.pad(inputs, [[0, 0], [0, 0],
|
||||
[0, outputs_per_step - (timesteps % outputs_per_step)]],
|
||||
|
|
|
@ -48,8 +48,26 @@ def copy_config_file(config_file, path):
|
|||
shutil.copyfile(config_file, out_path)
|
||||
|
||||
|
||||
def save_checkpoint(state, filename='checkpoint.pth.tar'):
|
||||
torch.save(state, filename)
|
||||
def save_checkpoint(model, model_loss, best_loss, out_path,
                    optimizer=None, current_step=None, epoch=None,
                    linear_loss=None, mel_loss=None):
    """Save a training checkpoint and, if the loss improved, a best-model copy.

    The original body read ``current_step``, ``optimizer``, ``epoch``,
    ``loss``, ``linear_loss`` and ``mel_loss`` although none of them were
    parameters or locals, so every call raised ``NameError``. They are now
    optional keyword arguments; the existing four-argument call keeps working.

    Args:
        model: object exposing ``state_dict()``.
        model_loss: total loss of the current step; stored as ``total_loss``
            and compared against ``best_loss``.
        best_loss: lowest loss seen so far.
        out_path: directory the checkpoint files are written into.
        optimizer: optional object exposing ``state_dict()``; ``None`` is
            stored when it is not given.
        current_step: optional global step. When given, the file is named
            ``checkpoint_<step>.pth.tar``; otherwise ``checkpoint.pth.tar``.
        epoch: optional epoch index stored in the checkpoint.
        linear_loss: optional linear-spectrogram loss stored in the checkpoint.
        mel_loss: optional mel-spectrogram loss stored in the checkpoint.

    Returns:
        The updated best loss: ``model_loss`` if it is lower than
        ``best_loss``, otherwise ``best_loss`` unchanged.
    """
    if current_step is None:
        # No step supplied: fall back to a fixed name (overwritten each call).
        checkpoint_name = 'checkpoint.pth.tar'
    else:
        checkpoint_name = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_name)
    print("\n | > Checkpoint saving : {}".format(checkpoint_path))

    state = {'model': model.state_dict(),
             'optimizer': (optimizer.state_dict()
                           if optimizer is not None else None),
             'step': current_step,
             'epoch': epoch,
             'total_loss': model_loss,
             'linear_loss': linear_loss,
             'mel_loss': mel_loss,
             'date': datetime.date.today().strftime("%B %d, %Y")}
    torch.save(state, checkpoint_path)

    if model_loss < best_loss:
        best_loss = model_loss
        # Single rolling file: each improvement overwrites the previous best.
        bestmodel_path = os.path.join(out_path, 'best_model.pth.tar')
        print("\n | > Best model saving with loss {} : {}".format(
            model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss
|
||||
|
||||
|
||||
def lr_decay(init_lr, global_step):
|
||||
|
|
Loading…
Reference in New Issue