small bug fixes

Eren Golge 2018-02-09 05:39:58 -08:00
parent 7d5bcd6ca4
commit dadefb5dbc
6 changed files with 51 additions and 25 deletions

View File

@@ -72,14 +72,15 @@ class LJSpeechDataset(Dataset):
         timesteps = mel.shape[2]
         # PAD with zeros that can be divided by outputs per step
-        # if timesteps % self.outputs_per_step != 0:
-        linear = pad_per_step(linear, self.outputs_per_step)
-        mel = pad_per_step(mel, self.outputs_per_step)
+        if timesteps % self.outputs_per_step != 0:
+            linear = pad_per_step(linear, self.outputs_per_step)
+            mel = pad_per_step(mel, self.outputs_per_step)
         # reshape jombo
         linear = linear.transpose(0, 2, 1)
         mel = mel.transpose(0, 2, 1)
         # convert things to pytorch
         text_lenghts = torch.LongTensor(text_lenghts)
         text = torch.LongTensor(text)
         linear = torch.FloatTensor(linear)
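
The restored guard matters because of how pad_per_step computes its pad width (see utils/data.py further down): outputs_per_step - (timesteps % outputs_per_step) collapses to a full outputs_per_step frames of spurious zero padding whenever timesteps is already divisible. A minimal sketch of the arithmetic, with hypothetical frame counts:

r = 5                      # outputs_per_step
for timesteps in (13, 15):
    pad = r - (timesteps % r)
    print(timesteps, pad)  # 13 -> 2 (needed), 15 -> 5 (spurious without the guard)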

View File

@@ -60,12 +60,12 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": 5,
     "metadata": {},
     "outputs": [],
     "source": [
-     "ROOT_PATH = '../result/February-03-2018_08:32PM/'\n",
-     "MODEL_PATH = ROOT_PATH + '/checkpoint_3800.pth.tar'\n",
+     "ROOT_PATH = '../result/February-08-2018_05:35AM/'\n",
+     "MODEL_PATH = ROOT_PATH + '/checkpoint_40600.pth.tar'\n",
      "CONFIG_PATH = ROOT_PATH + '/config.json'\n",
      "OUT_FOLDER = ROOT_PATH + '/test/'\n",
      "CONFIG = load_config(CONFIG_PATH)\n",
@@ -74,7 +74,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": 6,
     "metadata": {},
     "outputs": [
      {
@@ -83,6 +83,18 @@
       "text": [
        " | > Embedding dim : 149\n"
       ]
-     }
+     },
+     {
+      "ename": "TypeError",
+      "evalue": "__init__() missing 1 required positional argument: 'r'",
+      "output_type": "error",
+      "traceback": [
+       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+       "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+       "\u001b[0;32m<ipython-input-6-561a2605d2cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# load the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m model = Tacotron(CONFIG.embedding_size, CONFIG.hidden_size,\n\u001b[0;32m----> 3\u001b[0;31m CONFIG.num_mels, CONFIG.num_freq, CONFIG.r)\n\u001b[0m\u001b[1;32m 4\u001b[0m ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n\u001b[1;32m 5\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_shift_ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_length_ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreemphasis\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+       "\u001b[0;32m~/projects/TTS/models/tacotron.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, embedding_dim, linear_dim, mel_dim, freq_dim, r, padding_idx, use_memory_mask)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEncoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membedding_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmel_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpostnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCBHG\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmel_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprojections\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmel_dim\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+       "\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'r'"
+      ]
+     }
     ],
     "source": [
@@ -332,7 +344,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.3"
+   "version": "3.6.4"
   }
  },
 "nbformat": 4,

View File

@@ -121,6 +121,7 @@ def main(args):
    #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
    #                                 patience=c.lr_patience, verbose=True)
     epoch_time = 0
+    best_loss = float('inf')
     for epoch in range(0, c.epochs):
         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
@@ -131,7 +132,7 @@ def main(args):
             text_input = data[0]
             text_lengths = data[1]
-            magnitude_input = data[2]
+            linear_input = data[2]
             mel_input = data[3]

             current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
@@ -153,9 +154,10 @@ def main(args):
             # convert inputs to variables
             text_input_var = Variable(text_input)
             mel_spec_var = Variable(mel_input)
-            linear_spec_var = Variable(magnitude_input, volatile=True)
+            linear_spec_var = Variable(linear_input, volatile=True)

-            # sort sequence by length. Pytorch needs this.
+            # sort sequence by length.
+            # TODO: might be unnecessary
             sorted_lengths, indices = torch.sort(
                 text_lengths.view(-1), dim=0, descending=True)
             sorted_lengths = sorted_lengths.long().numpy()
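
The softened comment reflects a real question: descending-length sorting is only a hard requirement when the batch is later packed with pack_padded_sequence, which in PyTorch of this era rejected unsorted lengths (recent releases accept enforce_sorted=False). A self-contained sketch of the contract the sort satisfies:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

batch = torch.zeros(3, 7, 16)   # (batch, max_time, features)
lengths = [7, 4, 2]             # descending, as the sort above guarantees
packed = pack_padded_sequence(batch, lengths, batch_first=True)
print(packed.batch_sizes)       # tensor([3, 3, 2, 2, 1, 1, 1])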
@@ -211,18 +213,10 @@ def main(args):
                 tb.add_image('Attn/Alignment', align_img, current_step)

             if current_step % c.save_step == 0:
-                checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
-                checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
-                save_checkpoint({'model': model.state_dict(),
-                                 'optimizer': optimizer.state_dict(),
-                                 'step': current_step,
-                                 'epoch': epoch,
-                                 'total_loss': loss.data[0],
-                                 'linear_loss': linear_loss.data[0],
-                                 'mel_loss': mel_loss.data[0],
-                                 'date': datetime.date.today().strftime("%B %d, %Y")},
-                                checkpoint_path)
-                print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
+                # save model
+                best_loss = save_checkpoint(model, loss.data[0],
+                                            best_loss, out_path=OUT_PATH)

                 # Diagnostic visualizations
                 const_spec = linear_output[0].data.cpu().numpy()
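
The replacement threads best_loss through save_checkpoint's return value instead of rebuilding the state dict inline, which is why best_loss = float('inf') now seeds the epoch loop above. Reduced to its essentials, the pattern is (a sketch with made-up losses):

def save_if_best(model_loss, best_loss):
    # the caller stores whatever is returned as the running best
    if model_loss < best_loss:
        # torch.save(state, 'best_model.pth.tar') would go here
        return model_loss
    return best_loss

best_loss = float('inf')
for loss in (1.3, 0.9, 1.1, 0.7):
    best_loss = save_if_best(loss, best_loss)
print(best_loss)  # 0.7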

Binary file not shown.

View File

@@ -15,6 +15,7 @@ def prepare_data(inputs):

 def pad_per_step(inputs, outputs_per_step):
+    """zero pad inputs if it is not divisible with outputs_per_step (r)"""
     timesteps = inputs.shape[-1]
     return np.pad(inputs, [[0, 0], [0, 0],
                            [0, outputs_per_step - (timesteps % outputs_per_step)]],
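
A runnable check of the intended shape behavior, assuming 3-D inputs of (batch, bins, frames) and completing the np.pad call that the diff truncates with ordinary constant (zero) padding:

import numpy as np

def pad_per_step(inputs, outputs_per_step):
    """zero pad inputs if it is not divisible with outputs_per_step (r)"""
    timesteps = inputs.shape[-1]
    return np.pad(inputs, [[0, 0], [0, 0],
                           [0, outputs_per_step - (timesteps % outputs_per_step)]],
                  mode='constant', constant_values=0.0)

x = np.ones((2, 80, 13))          # 13 frames
print(pad_per_step(x, 5).shape)   # (2, 80, 15): rounded up to a multiple of r=5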

View File

@@ -48,8 +48,26 @@ def copy_config_file(config_file, path):
     shutil.copyfile(config_file, out_path)

-def save_checkpoint(state, filename='checkpoint.pth.tar'):
-    torch.save(state, filename)
+def save_checkpoint(model, model_loss, best_loss, out_path):
+    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
+    checkpoint_path = os.path.join(out_path, checkpoint_path)
+    print("\n | > Checkpoint saving : {}".format(checkpoint_path))
+    state = {'model': model.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'step': current_step,
+             'epoch': epoch,
+             'total_loss': loss.data[0],
+             'linear_loss': linear_loss.data[0],
+             'mel_loss': mel_loss.data[0],
+             'date': datetime.date.today().strftime("%B %d, %Y")}
+    torch.save(state, checkpoint_path)
+    if model_loss < best_loss:
+        best_loss = model_loss
+        bestmodel_path = 'best_model.pth.tar'.format(current_step)
+        bestmodel_path = os.path.join(out_path, bestmodel_path)
+        print("\n | > Best model saving with loss {} : {}".format(model_loss, bestmodel_path))
+        torch.save(state, bestmodel_path)
+    return best_loss

 def lr_decay(init_lr, global_step):
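
As committed, the new save_checkpoint body refers to current_step, optimizer, epoch, loss, linear_loss and mel_loss, none of which are parameters or globals of this module, so the call from train.py would raise NameError on the first save_step; 'best_model.pth.tar'.format(current_step) is likewise a no-op because the string has no placeholder. A self-contained variant preserving the intended behavior (a sketch, not the committed code):

import datetime
import os

import torch

def save_checkpoint(model, optimizer, model_loss, linear_loss, mel_loss,
                    best_loss, out_path, current_step, epoch):
    # regular checkpoint, one file per save step
    checkpoint_path = os.path.join(
        out_path, 'checkpoint_{}.pth.tar'.format(current_step))
    print("\n | > Checkpoint saving : {}".format(checkpoint_path))
    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'step': current_step,
             'epoch': epoch,
             'total_loss': model_loss,
             'linear_loss': linear_loss,
             'mel_loss': mel_loss,
             'date': datetime.date.today().strftime("%B %d, %Y")}
    torch.save(state, checkpoint_path)
    # keep a separate copy whenever the loss improves on the best seen so far
    if model_loss < best_loss:
        best_loss = model_loss
        bestmodel_path = os.path.join(out_path, 'best_model.pth.tar')
        print("\n | > Best model saving with loss {} : {}".format(
            model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss

# train.py's call site would then pass everything explicitly, e.g.:
# best_loss = save_checkpoint(model, optimizer, loss.data[0], linear_loss.data[0],
#                             mel_loss.data[0], best_loss, OUT_PATH, current_step, epoch)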