This commit is contained in:
Eren Golge 2018-03-09 15:37:58 -08:00
parent 332abb1de6
commit f434781ab3
2 changed files with 14 additions and 12 deletions

View File

@ -14,7 +14,7 @@
"epochs": 2000,
"lr": 0.00001875,
"warmup_steps": 4000,
"batch_size": 1,
"batch_size": 2,
"eval_batch_size": 32,
"r": 5,

View File

@ -177,9 +177,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
tb.add_audio('SampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
except:
print("\n > Error at audio signal on TB!!")
print(audio_signal.max())
print(audio_signal.min())
# print("\n > Error at audio signal on TB!!")
# print(audio_signal.max())
# print(audio_signal.min())
pass
avg_linear_loss /= (num_iter + 1)
@ -197,12 +198,12 @@ def train(model, criterion, data_loader, optimizer, epoch):
def evaluate(model, criterion, data_loader, current_step):
model = model.train()
model = model.eval()
epoch_time = 0
print(" | > Validation")
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
progbar = Progbar(len(data_loader.dataset) / c.eval_batch_size)
avg_linear_loss = 0
avg_mel_loss = 0
@ -271,9 +272,10 @@ def evaluate(model, criterion, data_loader, current_step):
tb.add_audio('ValSampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
except:
print(" | > Error at audio signal on TB!!")
print(audio_signal.max())
print(audio_signal.min())
# print(" | > Error at audio signal on TB!!")
# print(audio_signal.max())
# print(audio_signal.min())
pass
# compute average losses
avg_linear_loss /= (num_iter + 1)
@ -307,9 +309,9 @@ def main(args):
)
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
shuffle=False, collate_fn=train_dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers,
pin_memory=True)
shuffle=False, collate_fn=train_dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers,
pin_memory=True)
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
os.path.join(c.data_path, 'wavs'),