From dadefb5dbcec900b7e117b4d865286af2858f133 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 9 Feb 2018 05:39:58 -0800 Subject: [PATCH] small bug fixes --- datasets/LJSpeech.py | 7 ++++--- notebooks/PlayGround.ipynb | 22 +++++++++++++++++----- train.py | 24 +++++++++--------------- utils/.data.py.swp | Bin 12288 -> 12288 bytes utils/data.py | 1 + utils/generic_utils.py | 22 ++++++++++++++++++++-- 6 files changed, 51 insertions(+), 25 deletions(-) diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index f95dcfc6..98f65daa 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -72,14 +72,15 @@ class LJSpeechDataset(Dataset): timesteps = mel.shape[2] # PAD with zeros that can be divided by outputs per step - # if timesteps % self.outputs_per_step != 0: - linear = pad_per_step(linear, self.outputs_per_step) - mel = pad_per_step(mel, self.outputs_per_step) + if timesteps % self.outputs_per_step != 0: + linear = pad_per_step(linear, self.outputs_per_step) + mel = pad_per_step(mel, self.outputs_per_step) # reshape jombo linear = linear.transpose(0, 2, 1) mel = mel.transpose(0, 2, 1) + # convert things to pytorch text_lenghts = torch.LongTensor(text_lenghts) text = torch.LongTensor(text) linear = torch.FloatTensor(linear) diff --git a/notebooks/PlayGround.ipynb b/notebooks/PlayGround.ipynb index 6c2aa281..30fb9f9d 100644 --- a/notebooks/PlayGround.ipynb +++ b/notebooks/PlayGround.ipynb @@ -60,12 +60,12 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "ROOT_PATH = '../result/February-03-2018_08:32PM/'\n", - "MODEL_PATH = ROOT_PATH + '/checkpoint_3800.pth.tar'\n", + "ROOT_PATH = '../result/February-08-2018_05:35AM/'\n", + "MODEL_PATH = ROOT_PATH + '/checkpoint_40600.pth.tar'\n", "CONFIG_PATH = ROOT_PATH + '/config.json'\n", "OUT_FOLDER = ROOT_PATH + '/test/'\n", "CONFIG = load_config(CONFIG_PATH)\n", @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -83,6 +83,18 @@ "text": [ " | > Embedding dim : 149\n" ] + }, + { + "ename": "TypeError", + "evalue": "__init__() missing 1 required positional argument: 'r'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# load the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m model = Tacotron(CONFIG.embedding_size, CONFIG.hidden_size,\n\u001b[0;32m----> 3\u001b[0;31m CONFIG.num_mels, CONFIG.num_freq, CONFIG.r)\n\u001b[0m\u001b[1;32m 4\u001b[0m ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n\u001b[1;32m 5\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_shift_ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_length_ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCONFIG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreemphasis\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/TTS/models/tacotron.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, embedding_dim, linear_dim, mel_dim, freq_dim, r, padding_idx, use_memory_mask)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEncoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membedding_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmel_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpostnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCBHG\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmel_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprojections\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmel_dim\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: __init__() missing 1 required positional argument: 'r'" + ] } ], "source": [ @@ -332,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.6.4" } }, "nbformat": 4, diff --git a/train.py b/train.py index 59ec7447..817cea52 100644 --- a/train.py +++ b/train.py @@ -121,6 +121,7 @@ def main(args): #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay, # patience=c.lr_patience, verbose=True) epoch_time = 0 + best_loss = float('inf') for epoch in range(0, c.epochs): print("\n | > Epoch {}/{}".format(epoch, c.epochs)) @@ -131,7 +132,7 @@ def main(args): text_input = data[0] text_lengths = data[1] - magnitude_input = data[2] + linear_input = data[2] mel_input = data[3] current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1 @@ -153,9 +154,10 @@ def main(args): # convert inputs to variables text_input_var = Variable(text_input) mel_spec_var = Variable(mel_input) - linear_spec_var = Variable(magnitude_input, volatile=True) + linear_spec_var = Variable(linear_input, volatile=True) - # sort sequence by length. Pytorch needs this. + # sort sequence by length. + # TODO: might be unnecessary sorted_lengths, indices = torch.sort( text_lengths.view(-1), dim=0, descending=True) sorted_lengths = sorted_lengths.long().numpy() @@ -211,18 +213,10 @@ def main(args): tb.add_image('Attn/Alignment', align_img, current_step) if current_step % c.save_step == 0: - checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) - checkpoint_path = os.path.join(OUT_PATH, checkpoint_path) - save_checkpoint({'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'step': current_step, - 'epoch': epoch, - 'total_loss': loss.data[0], - 'linear_loss': linear_loss.data[0], - 'mel_loss': mel_loss.data[0], - 'date': datetime.date.today().strftime("%B %d, %Y")}, - checkpoint_path) - print("\n | > Checkpoint is saved : {}".format(checkpoint_path)) + + # save model + best_loss = save_checkpoint(model, loss.data[0], + best_loss, out_path=OUT_PATH) # Diagnostic visualizations const_spec = linear_output[0].data.cpu().numpy() diff --git a/utils/.data.py.swp b/utils/.data.py.swp index cf746ca8db91b40ae19afec4fcbf3d21e9bc95d4..0aadc6313386d21cc4ae28c00c5e61c51f6c7936 100644 GIT binary patch delta 534 zcmZY5KTE?v7zXguq9C-kMO#FnyoO>MNt5^o1R*#pI4KI&l16jlp-n<^RT0!8h;zBY z#V_DstMoJY1sq(RMHg`w-?UPszVL$_$@ARhYUNsad`I6{UWEm9p5P=1$zGprElp)e zbF@Bg_ph%#UfGd)d7$+c$wm9ptz_aW9(#`xV#Ye(r$rJ*R-WU8JfQ}`JJ&7XBqS3s<)kaDZstMb6~ zsBZ?GdUC0x5b;__iSxpsXdpu!5_hU}5vCsNgj?2MjM=W4F zhZK$(--q6E$on!@5S#9W72Qqqoi62HKwiYF3pe&@oL7AZ98V3nwR?~+if?sO_I&DD zJ~eHNTfZgUNEqvmxYs5!-BtUc;@Z^6?Ya&YIy|R{il%k$%nAbPb2w2Qn^nLtV1|SV PGyH1<^!D+^i#GiOEgg%g delta 396 zcmZojXh@JsG6?hZRWR2xVE_UF28IU$Nl`xXatxZ2CQ7BSCno2Y6y@h_Y)E0={E}In zhf!d&pulH--UGZ03^IHW9RiaD6@-H>@-r|T0pdI$&IRHeAPxj#RUm!=RQVi;p8@e{ zAYKK;Al+O*`~_&zeIPyp#9M%PG7z@_aVGC(MS;&elV9sf^D6*BQEEwPQJz9xfnEVn zOi=($ED@y87F{q!BQvi6t6(XnD9{8gc7 Checkpoint saving : {}".format(checkpoint_path)) + state = {'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'step': current_step, + 'epoch': epoch, + 'total_loss': loss.data[0], + 'linear_loss': linear_loss.data[0], + 'mel_loss': mel_loss.data[0], + 'date': datetime.date.today().strftime("%B %d, %Y")} + torch.save(state, checkpoint_path) + if model_loss < best_loss: + best_loss = model_loss + bestmodel_path = 'best_model.pth.tar'.format(current_step) + bestmodel_path = os.path.join(out_path, bestmodel_path) + print("\n | > Best model saving with loss {} : {}".format(model_loss, bestmodel_path)) + torch.save(state, bestmodel_path) + return best_loss def lr_decay(init_lr, global_step):