From 1b9f07918e508cdd88dc30240363c64063ddd9bf Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 9 Mar 2018 15:37:58 -0800 Subject: [PATCH] bug fix --- config.json | 2 +- train.py | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/config.json b/config.json index b4fbd7a5..29b62a11 100644 --- a/config.json +++ b/config.json @@ -14,7 +14,7 @@ "epochs": 2000, "lr": 0.00001875, "warmup_steps": 4000, - "batch_size": 1, + "batch_size": 2, "eval_batch_size": 32, "r": 5, diff --git a/train.py b/train.py index 8ca4ed78..3b7ff638 100644 --- a/train.py +++ b/train.py @@ -177,9 +177,10 @@ def train(model, criterion, data_loader, optimizer, epoch): tb.add_audio('SampleAudio', audio_signal, current_step, sample_rate=c.sample_rate) except: - print("\n > Error at audio signal on TB!!") - print(audio_signal.max()) - print(audio_signal.min()) + # print("\n > Error at audio signal on TB!!") + # print(audio_signal.max()) + # print(audio_signal.min()) + pass avg_linear_loss /= (num_iter + 1) @@ -197,12 +198,12 @@ def train(model, criterion, data_loader, optimizer, epoch): def evaluate(model, criterion, data_loader, current_step): - model = model.train() + model = model.eval() epoch_time = 0 print(" | > Validation") n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) - progbar = Progbar(len(data_loader.dataset) / c.batch_size) + progbar = Progbar(len(data_loader.dataset) / c.eval_batch_size) avg_linear_loss = 0 avg_mel_loss = 0 @@ -271,9 +272,10 @@ def evaluate(model, criterion, data_loader, current_step): tb.add_audio('ValSampleAudio', audio_signal, current_step, sample_rate=c.sample_rate) except: - print(" | > Error at audio signal on TB!!") - print(audio_signal.max()) - print(audio_signal.min()) + # print(" | > Error at audio signal on TB!!") + # print(audio_signal.max()) + # print(audio_signal.min()) + pass # compute average losses avg_linear_loss /= (num_iter + 1) @@ -307,9 +309,9 @@ def main(args): ) train_loader = DataLoader(train_dataset, batch_size=c.batch_size, - shuffle=False, collate_fn=train_dataset.collate_fn, - drop_last=False, num_workers=c.num_loader_workers, - pin_memory=True) + shuffle=False, collate_fn=train_dataset.collate_fn, + drop_last=False, num_workers=c.num_loader_workers, + pin_memory=True) val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'), os.path.join(c.data_path, 'wavs'),