diff --git a/config.json b/config.json index 4666e5d0..b4fbd7a5 100644 --- a/config.json +++ b/config.json @@ -15,6 +15,7 @@ "lr": 0.00001875, "warmup_steps": 4000, "batch_size": 1, + "eval_batch_size": 32, "r": 5, "griffin_lim_iters": 60, diff --git a/train.py b/train.py index 288dad24..8ca4ed78 100644 --- a/train.py +++ b/train.py @@ -326,7 +326,7 @@ def main(args): c.power ) - val_loader = DataLoader(val_dataset, batch_size=c.batch_size, + val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size, shuffle=False, collate_fn=val_dataset.collate_fn, drop_last=False, num_workers= 4, pin_memory=True)