From b8885fe4b6f2523e273ad0b45ab074f68bb71386 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 30 Apr 2018 06:01:02 -0700
Subject: [PATCH] add stop token to tacotron testing

---
 config.json             | 12 ++++++------
 tests/tacotron_tests.py | 14 +++++++++++---
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/config.json b/config.json
index df3b3372..190d8957 100644
--- a/config.json
+++ b/config.json
@@ -12,7 +12,7 @@
     "text_cleaner": "english_cleaners",

     "epochs": 50,
-    "lr": 0.004,
+    "lr": 0.002,
     "warmup_steps": 4000,
     "batch_size": 32,
     "eval_batch_size":32,
@@ -23,14 +23,14 @@
     "griffin_lim_iters": 60,
     "power": 1.2,

-    "dataset": "TWEB",
-    "meta_file_train": "transcript_train.txt",
-    "meta_file_val": "transcript_val.txt",
-    "data_path": "/data/shared/BibleSpeech/",
+    "dataset": "LJSpeech",
+    "meta_file_train": "metadata_train.csv",
+    "meta_file_val": "metadata_val.csv",
+    "data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
     "min_seq_len": 0,
     "num_loader_workers": 8,

     "checkpoint": true,
-    "save_step": 908,
+    "save_step": 600,
     "output_path": "/data/shared/erogol_models/"
 }
diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py
index 378ae1f1..90bad689 100644
--- a/tests/tacotron_tests.py
+++ b/tests/tacotron_tests.py
@@ -5,6 +5,7 @@ import unittest
 import numpy as np

 from torch import optim
+from torch import nn
 from TTS.utils.generic_utils import load_config
 from TTS.layers.losses import L1LossMasked
 from TTS.models.tacotron import Tacotron
@@ -24,7 +25,11 @@ class TacotronTrainTest(unittest.TestCase):
         mel_spec = torch.rand(8, 30, c.num_mels).to(device)
         linear_spec = torch.rand(8, 30, c.num_freq).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()):, 0] = 1.0
         criterion = L1LossMasked().to(device)
+        criterion_st = nn.BCELoss().to(device)
         model = Tacotron(c.embedding_size,
                          c.num_freq,
                          c.num_mels,
@@ -37,17 +42,20 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
-            mel_out, linear_out, align = model.forward(input, mel_spec)
+            mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
+            assert stop_tokens.data.max() <= 1.0
+            assert stop_tokens.data.min() >= 0.0
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
-            loss = 0.5 * loss + 0.5 * criterion(linear_out, linear_spec, mel_lengths)
+            stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()

         # check parameter changes
         count = 0
         for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
-            if count not in [139, 59]:
+            if count not in [141, 59]:
                 assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
             count += 1
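Note on the new stop-token supervision in the test: the targets are a (batch, frames, 1) tensor that stays at 0.0 while a frame is inside the sequence and switches to 1.0 once the sequence ends, and nn.BCELoss compares them against the model's predicted stop_tokens, which is why the two new assertions check that the predictions stay in [0, 1]. The sketch below is a minimal, self-contained illustration of that target construction, written per sample rather than over the whole batch; the helper name make_stop_targets is hypothetical and not part of the patch.

import torch
from torch import nn

def make_stop_targets(mel_lengths, max_len):
    # Hypothetical helper (illustration only): build (batch, max_len, 1)
    # stop targets that are 0.0 inside each sequence and 1.0 from each
    # sample's own length onward.
    stop_targets = torch.zeros(mel_lengths.size(0), max_len, 1)
    for b, length in enumerate(mel_lengths):
        stop_targets[b, int(length.item()):, 0] = 1.0
    return stop_targets

# Same shapes as the test above: a batch of 8 sequences padded to 30 frames.
mel_lengths = torch.randint(20, 30, (8,)).long()
stop_targets = make_stop_targets(mel_lengths, max_len=30)

# nn.BCELoss expects probabilities in [0, 1], matching the assertions on
# stop_tokens in the patched test.
criterion_st = nn.BCELoss()
stop_tokens = torch.rand(8, 30, 1)  # stand-in for the model's stop predictions
stop_loss = criterion_st(stop_tokens, stop_targets)
print(stop_loss.item())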