From 851718751131fb8b17906659cde3e3c2e8112fd8 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 2 Mar 2018 08:01:04 -0800 Subject: [PATCH] Run ready --- config.json | 2 +- datasets/LJSpeech.py | 1 - train.py | 2 ++ 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/config.json b/config.json index cd3aef72..0ad2921b 100644 --- a/config.json +++ b/config.json @@ -25,5 +25,5 @@ "checkpoint": false, "save_step": 69, "data_path": "/run/shm/erogol/LJSpeech-1.0", - "output_path": "result", + "output_path": "result" } diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index ded16ed5..334047a1 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -18,7 +18,6 @@ class LJSpeechDataset(Dataset): with open(csv_file, "r") as f: self.frames = [line.split('|') for line in f] - self.frames = self.frames[:256] self.root_dir = root_dir self.outputs_per_step = outputs_per_step self.sample_rate = sample_rate diff --git a/train.py b/train.py index c806f965..53c5698d 100644 --- a/train.py +++ b/train.py @@ -285,6 +285,7 @@ def evaluate(model, criterion, data_loader, current_step): def main(args): + # Setup the dataset train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), os.path.join(c.data_path, 'wavs'), c.r, @@ -325,6 +326,7 @@ def main(args): drop_last=True, num_workers= 4, pin_memory=True) + model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq,